mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
feat(auth): authentication module with multi-tenant isolation (RFC-001)
Introduce an always-on auth layer with auto-created admin on first boot, multi-tenant isolation for threads/stores, and a full setup/login flow. Backend - JWT access tokens with `ver` field for stale-token rejection; bump on password/email change - Password hashing, HttpOnly+Secure cookies (Secure derived from request scheme at runtime) - CSRF middleware covering both REST and LangGraph routes - IP-based login rate limiting (5 attempts / 5-min lockout) with bounded dict growth and X-Forwarded-For bypass fix - Multi-worker-safe admin auto-creation (single DB write, WAL once) - needs_setup + token_version on User model; SQLite schema migration - Thread/store isolation by owner; orphan thread migration on first admin registration - thread_id validated as UUID to prevent log injection - CLI tool to reset admin password - Decorator-based authz module extracted from auth core Frontend - Login and setup pages with SSR guard for needs_setup flow - Account settings page (change password / email) - AuthProvider + route guards; skips redirect when no users registered - i18n (en-US / zh-CN) for auth surfaces - Typed auth API client; parseAuthError unwraps FastAPI detail envelope Infra & tooling - Unified `serve.sh` with gateway mode + auto dep install - Public PyPI uv.toml pin for CI compatibility - Regenerated uv.lock with public index Tests - HTTP vs HTTPS cookie security tests - Auth middleware, rate limiter, CSRF, setup flow coverage
This commit is contained in:
@@ -0,0 +1,506 @@
|
||||
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
|
||||
|
||||
from datetime import timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.authz import (
|
||||
AuthContext,
|
||||
Permissions,
|
||||
get_auth_context,
|
||||
require_auth,
|
||||
require_permission,
|
||||
)
|
||||
|
||||
# ── Password Hashing ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_hash_password_and_verify():
|
||||
"""Hashing and verification round-trip."""
|
||||
password = "s3cr3tP@ssw0rd!"
|
||||
hashed = hash_password(password)
|
||||
assert hashed != password
|
||||
assert verify_password(password, hashed) is True
|
||||
assert verify_password("wrongpassword", hashed) is False
|
||||
|
||||
|
||||
def test_hash_password_different_each_time():
|
||||
"""bcrypt generates unique salts, so same password has different hashes."""
|
||||
password = "testpassword"
|
||||
h1 = hash_password(password)
|
||||
h2 = hash_password(password)
|
||||
assert h1 != h2 # Different salts
|
||||
# But both verify correctly
|
||||
assert verify_password(password, h1) is True
|
||||
assert verify_password(password, h2) is True
|
||||
|
||||
|
||||
def test_verify_password_rejects_empty():
|
||||
"""Empty password should not verify."""
|
||||
hashed = hash_password("nonempty")
|
||||
assert verify_password("", hashed) is False
|
||||
|
||||
|
||||
# ── JWT ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_create_and_decode_token():
|
||||
"""JWT creation and decoding round-trip."""
|
||||
user_id = str(uuid4())
|
||||
# Set a valid JWT secret for this test
|
||||
import os
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(user_id)
|
||||
assert isinstance(token, str)
|
||||
|
||||
payload = decode_token(token)
|
||||
assert payload is not None
|
||||
assert payload.sub == user_id
|
||||
|
||||
|
||||
def test_decode_token_expired():
|
||||
"""Expired token returns TokenError.EXPIRED."""
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
user_id = str(uuid4())
|
||||
# Create token that expires immediately
|
||||
token = create_access_token(user_id, expires_delta=timedelta(seconds=-1))
|
||||
payload = decode_token(token)
|
||||
assert payload == TokenError.EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_invalid():
|
||||
"""Invalid token returns TokenError."""
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
assert isinstance(decode_token("not.a.valid.token"), TokenError)
|
||||
assert isinstance(decode_token(""), TokenError)
|
||||
assert isinstance(decode_token("completely-wrong"), TokenError)
|
||||
|
||||
|
||||
def test_create_token_custom_expiry():
|
||||
"""Custom expiry is respected."""
|
||||
user_id = str(uuid4())
|
||||
token = create_access_token(user_id, expires_delta=timedelta(hours=1))
|
||||
payload = decode_token(token)
|
||||
assert payload is not None
|
||||
assert payload.sub == user_id
|
||||
|
||||
|
||||
# ── AuthContext ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_context_unauthenticated():
|
||||
"""AuthContext with no user."""
|
||||
ctx = AuthContext(user=None, permissions=[])
|
||||
assert ctx.is_authenticated is False
|
||||
assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_authenticated_no_perms():
|
||||
"""AuthContext with user but no permissions."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[])
|
||||
assert ctx.is_authenticated is True
|
||||
assert ctx.has_permission("threads", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_has_permission():
|
||||
"""AuthContext permission checking."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
|
||||
ctx = AuthContext(user=user, permissions=perms)
|
||||
assert ctx.has_permission("threads", "read") is True
|
||||
assert ctx.has_permission("threads", "write") is True
|
||||
assert ctx.has_permission("threads", "delete") is False
|
||||
assert ctx.has_permission("runs", "read") is False
|
||||
|
||||
|
||||
def test_auth_context_require_user_raises():
|
||||
"""require_user raises 401 when not authenticated."""
|
||||
ctx = AuthContext(user=None, permissions=[])
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
ctx.require_user()
|
||||
assert exc_info.value.status_code == 401
|
||||
|
||||
|
||||
def test_auth_context_require_user_returns_user():
|
||||
"""require_user returns user when authenticated."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[])
|
||||
returned = ctx.require_user()
|
||||
assert returned == user
|
||||
|
||||
|
||||
# ── get_auth_context helper ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_auth_context_not_set():
|
||||
"""get_auth_context returns None when auth not set on request."""
|
||||
mock_request = MagicMock()
|
||||
# Make getattr return None (simulating attribute not set)
|
||||
mock_request.state = MagicMock()
|
||||
del mock_request.state.auth
|
||||
assert get_auth_context(mock_request) is None
|
||||
|
||||
|
||||
def test_get_auth_context_set():
|
||||
"""get_auth_context returns the AuthContext from request."""
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
ctx = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.state.auth = ctx
|
||||
|
||||
assert get_auth_context(mock_request) == ctx
|
||||
|
||||
|
||||
# ── require_auth decorator ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_auth_sets_auth_context():
|
||||
"""require_auth sets auth context on request from cookie."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
@require_auth
|
||||
async def endpoint(request: Request):
|
||||
ctx = get_auth_context(request)
|
||||
return {"authenticated": ctx.is_authenticated}
|
||||
|
||||
with TestClient(app) as client:
|
||||
# No cookie → anonymous
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 200
|
||||
assert response.json()["authenticated"] is False
|
||||
|
||||
|
||||
def test_require_auth_requires_request_param():
|
||||
"""require_auth raises ValueError if request parameter is missing."""
|
||||
import asyncio
|
||||
|
||||
@require_auth
|
||||
async def bad_endpoint(): # Missing `request` parameter
|
||||
pass
|
||||
|
||||
with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
|
||||
asyncio.run(bad_endpoint())
|
||||
|
||||
|
||||
# ── require_permission decorator ─────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_require_permission_requires_auth():
|
||||
"""require_permission raises 401 when not authenticated."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.get("/test")
|
||||
@require_permission("threads", "read")
|
||||
async def endpoint(request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Authentication required" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_require_permission_denies_wrong_permission():
|
||||
"""User without required permission gets 403."""
|
||||
from fastapi import Request
|
||||
|
||||
app = FastAPI()
|
||||
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
|
||||
@app.get("/test")
|
||||
@require_permission("threads", "delete")
|
||||
async def endpoint(request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/test")
|
||||
assert response.status_code == 403
|
||||
assert "Permission denied" in response.json()["detail"]
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── User Model Fields ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_user_model_has_needs_setup_default_false():
|
||||
"""New users default to needs_setup=False."""
|
||||
user = User(email="test@example.com", password_hash="hash")
|
||||
assert user.needs_setup is False
|
||||
|
||||
|
||||
def test_user_model_has_token_version_default_zero():
|
||||
"""New users default to token_version=0."""
|
||||
user = User(email="test@example.com", password_hash="hash")
|
||||
assert user.token_version == 0
|
||||
|
||||
|
||||
def test_user_model_needs_setup_true():
|
||||
"""Auto-created admin has needs_setup=True."""
|
||||
user = User(email="admin@example.com", password_hash="hash", needs_setup=True)
|
||||
assert user.needs_setup is True
|
||||
|
||||
|
||||
def test_sqlite_round_trip_new_fields():
|
||||
"""needs_setup and token_version survive create → read round-trip."""
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.gateway.auth.repositories import sqlite as sqlite_mod
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_users.db")
|
||||
old_path = sqlite_mod._resolved_db_path
|
||||
old_init = sqlite_mod._table_initialized
|
||||
sqlite_mod._resolved_db_path = Path(db_path)
|
||||
sqlite_mod._table_initialized = False
|
||||
try:
|
||||
repo = sqlite_mod.SQLiteUserRepository()
|
||||
user = User(
|
||||
email="setup@test.com",
|
||||
password_hash="fakehash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
)
|
||||
created = asyncio.run(repo.create_user(user))
|
||||
assert created.needs_setup is True
|
||||
assert created.token_version == 3
|
||||
|
||||
fetched = asyncio.run(repo.get_user_by_email("setup@test.com"))
|
||||
assert fetched is not None
|
||||
assert fetched.needs_setup is True
|
||||
assert fetched.token_version == 3
|
||||
|
||||
fetched.needs_setup = False
|
||||
fetched.token_version = 4
|
||||
asyncio.run(repo.update_user(fetched))
|
||||
refetched = asyncio.run(repo.get_user_by_id(str(fetched.id)))
|
||||
assert refetched.needs_setup is False
|
||||
assert refetched.token_version == 4
|
||||
finally:
|
||||
sqlite_mod._resolved_db_path = old_path
|
||||
sqlite_mod._table_initialized = old_init
|
||||
|
||||
|
||||
# ── Token Versioning ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_jwt_encodes_ver():
|
||||
"""JWT payload includes ver field."""
|
||||
import os
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(str(uuid4()), token_version=3)
|
||||
payload = decode_token(token)
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.ver == 3
|
||||
|
||||
|
||||
def test_jwt_default_ver_zero():
|
||||
"""JWT ver defaults to 0."""
|
||||
import os
|
||||
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
token = create_access_token(str(uuid4()))
|
||||
payload = decode_token(token)
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.ver == 0
|
||||
|
||||
|
||||
def test_token_version_mismatch_rejects():
|
||||
"""Token with stale ver is rejected by get_current_user_from_request."""
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||
|
||||
user_id = str(uuid4())
|
||||
token = create_access_token(user_id, token_version=0)
|
||||
|
||||
mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.cookies = {"access_token": token}
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.get_user = AsyncMock(return_value=mock_user)
|
||||
mock_provider_fn.return_value = mock_provider
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
assert "revoked" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
# ── change-password extension ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_change_password_request_accepts_new_email():
|
||||
"""ChangePasswordRequest model accepts optional new_email."""
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
|
||||
req = ChangePasswordRequest(
|
||||
current_password="old",
|
||||
new_password="newpassword",
|
||||
new_email="new@example.com",
|
||||
)
|
||||
assert req.new_email == "new@example.com"
|
||||
|
||||
|
||||
def test_change_password_request_new_email_optional():
|
||||
"""ChangePasswordRequest model works without new_email."""
|
||||
from app.gateway.routers.auth import ChangePasswordRequest
|
||||
|
||||
req = ChangePasswordRequest(current_password="old", new_password="newpassword")
|
||||
assert req.new_email is None
|
||||
|
||||
|
||||
def test_login_response_includes_needs_setup():
|
||||
"""LoginResponse includes needs_setup field."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
resp = LoginResponse(expires_in=3600, needs_setup=True)
|
||||
assert resp.needs_setup is True
|
||||
resp2 = LoginResponse(expires_in=3600)
|
||||
assert resp2.needs_setup is False
|
||||
|
||||
|
||||
# ── Rate Limiting ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_rate_limiter_allows_under_limit():
|
||||
"""Requests under the limit are allowed."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts
|
||||
|
||||
_login_attempts.clear()
|
||||
_check_rate_limit("192.168.1.1") # Should not raise
|
||||
|
||||
|
||||
def test_rate_limiter_blocks_after_max_failures():
|
||||
"""IP is blocked after 5 consecutive failures."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure
|
||||
|
||||
_login_attempts.clear()
|
||||
ip = "10.0.0.1"
|
||||
for _ in range(5):
|
||||
_record_login_failure(ip)
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
_check_rate_limit(ip)
|
||||
assert exc_info.value.status_code == 429
|
||||
|
||||
|
||||
def test_rate_limiter_resets_on_success():
|
||||
"""Successful login clears the failure counter."""
|
||||
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
|
||||
|
||||
_login_attempts.clear()
|
||||
ip = "10.0.0.2"
|
||||
for _ in range(4):
|
||||
_record_login_failure(ip)
|
||||
_record_login_success(ip)
|
||||
_check_rate_limit(ip) # Should not raise
|
||||
|
||||
|
||||
# ── Client IP extraction ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_get_client_ip_direct_connection():
|
||||
"""Without nginx (no X-Real-IP), falls back to request.client.host."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "203.0.113.42"
|
||||
req.headers = {}
|
||||
assert _get_client_ip(req) == "203.0.113.42"
|
||||
|
||||
|
||||
def test_get_client_ip_uses_x_real_ip():
|
||||
"""X-Real-IP (set by nginx) is used when present."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "10.0.0.1" # uvicorn may have replaced this with XFF[0]
|
||||
req.headers = {"x-real-ip": "203.0.113.42"}
|
||||
assert _get_client_ip(req) == "203.0.113.42"
|
||||
|
||||
|
||||
def test_get_client_ip_xff_ignored():
|
||||
"""X-Forwarded-For is never used; only X-Real-IP matters."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "10.0.0.1"
|
||||
req.headers = {"x-forwarded-for": "10.0.0.1, 198.51.100.5", "x-real-ip": "198.51.100.5"}
|
||||
assert _get_client_ip(req) == "198.51.100.5"
|
||||
|
||||
|
||||
def test_get_client_ip_no_real_ip_fallback():
|
||||
"""No X-Real-IP → falls back to client.host (direct connection)."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "127.0.0.1"
|
||||
req.headers = {}
|
||||
assert _get_client_ip(req) == "127.0.0.1"
|
||||
|
||||
|
||||
def test_get_client_ip_x_real_ip_always_preferred():
|
||||
"""X-Real-IP is always preferred over client.host regardless of IP."""
|
||||
from app.gateway.routers.auth import _get_client_ip
|
||||
|
||||
req = MagicMock()
|
||||
req.client.host = "203.0.113.99"
|
||||
req.headers = {"x-real-ip": "198.51.100.7"}
|
||||
assert _get_client_ip(req) == "198.51.100.7"
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
|
||||
"""get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as config_module
|
||||
|
||||
config_module._auth_config = None
|
||||
monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
config = config_module.get_auth_config()
|
||||
|
||||
assert config.jwt_secret # non-empty ephemeral secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
|
||||
# Cleanup
|
||||
config_module._auth_config = None
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Tests for AuthConfig typed configuration."""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.gateway.auth.config import AuthConfig
|
||||
|
||||
|
||||
def test_auth_config_defaults():
|
||||
config = AuthConfig(jwt_secret="test-secret-key-123")
|
||||
assert config.token_expiry_days == 7
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_range():
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=1)
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=30)
|
||||
with pytest.raises(Exception):
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=0)
|
||||
with pytest.raises(Exception):
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=31)
|
||||
|
||||
|
||||
def test_auth_config_from_env():
|
||||
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret == "test-jwt-secret-from-env"
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
|
||||
|
||||
def test_auth_config_missing_secret_generates_ephemeral(caplog):
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Tests for auth error types and typed decode_token."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||
|
||||
|
||||
def test_auth_error_code_values():
|
||||
assert AuthErrorCode.INVALID_CREDENTIALS == "invalid_credentials"
|
||||
assert AuthErrorCode.TOKEN_EXPIRED == "token_expired"
|
||||
assert AuthErrorCode.NOT_AUTHENTICATED == "not_authenticated"
|
||||
|
||||
|
||||
def test_token_error_values():
|
||||
assert TokenError.EXPIRED == "expired"
|
||||
assert TokenError.INVALID_SIGNATURE == "invalid_signature"
|
||||
assert TokenError.MALFORMED == "malformed"
|
||||
|
||||
|
||||
def test_auth_error_response_serialization():
|
||||
err = AuthErrorResponse(
|
||||
code=AuthErrorCode.TOKEN_EXPIRED,
|
||||
message="Token has expired",
|
||||
)
|
||||
d = err.model_dump()
|
||||
assert d == {"code": "token_expired", "message": "Token has expired"}
|
||||
|
||||
|
||||
def test_auth_error_response_from_dict():
|
||||
d = {"code": "invalid_credentials", "message": "Wrong password"}
|
||||
err = AuthErrorResponse(**d)
|
||||
assert err.code == AuthErrorCode.INVALID_CREDENTIALS
|
||||
|
||||
|
||||
# ── decode_token typed failure tests ──────────────────────────────
|
||||
|
||||
_TEST_SECRET = "test-secret-for-jwt-decode-token-tests"
|
||||
|
||||
|
||||
def _setup_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
|
||||
|
||||
def test_decode_token_returns_token_error_on_expired():
|
||||
_setup_config()
|
||||
expired_payload = {"sub": "user-1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired_payload, _TEST_SECRET, algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_returns_token_error_on_bad_signature():
|
||||
_setup_config()
|
||||
payload = {"sub": "user-1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.INVALID_SIGNATURE
|
||||
|
||||
|
||||
def test_decode_token_returns_token_error_on_malformed():
|
||||
_setup_config()
|
||||
result = decode_token("not-a-jwt")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
|
||||
def test_decode_token_returns_payload_on_valid():
|
||||
_setup_config()
|
||||
token = create_access_token("user-123")
|
||||
result = decode_token(token)
|
||||
assert not isinstance(result, TokenError)
|
||||
assert result.sub == "user-123"
|
||||
@@ -0,0 +1,216 @@
|
||||
"""Tests for the global AuthMiddleware (fail-closed safety net)."""
|
||||
|
||||
import pytest
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
||||
|
||||
# ── _is_public unit tests ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/health",
|
||||
"/health/",
|
||||
"/docs",
|
||||
"/docs/",
|
||||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/setup-status",
|
||||
],
|
||||
)
|
||||
def test_public_paths(path: str):
|
||||
assert _is_public(path) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/api/models",
|
||||
"/api/mcp/config",
|
||||
"/api/memory",
|
||||
"/api/skills",
|
||||
"/api/threads/123",
|
||||
"/api/threads/123/uploads",
|
||||
"/api/agents",
|
||||
"/api/channels",
|
||||
"/api/runs/stream",
|
||||
"/api/threads/123/runs",
|
||||
"/api/v1/auth/me",
|
||||
"/api/v1/auth/change-password",
|
||||
],
|
||||
)
|
||||
def test_protected_paths(path: str):
|
||||
assert _is_public(path) is False
|
||||
|
||||
|
||||
# ── Trailing slash / normalization edge cases ─────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/api/v1/auth/login/local/",
|
||||
"/api/v1/auth/register/",
|
||||
"/api/v1/auth/logout/",
|
||||
"/api/v1/auth/setup-status/",
|
||||
],
|
||||
)
|
||||
def test_public_auth_paths_with_trailing_slash(path: str):
|
||||
assert _is_public(path) is True
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"path",
|
||||
[
|
||||
"/api/models/",
|
||||
"/api/v1/auth/me/",
|
||||
"/api/v1/auth/change-password/",
|
||||
],
|
||||
)
|
||||
def test_protected_paths_with_trailing_slash(path: str):
|
||||
assert _is_public(path) is False
|
||||
|
||||
|
||||
def test_unknown_api_path_is_protected():
|
||||
"""Fail-closed: any new /api/* path is protected by default."""
|
||||
assert _is_public("/api/new-feature") is False
|
||||
assert _is_public("/api/v2/something") is False
|
||||
assert _is_public("/api/v1/auth/new-endpoint") is False
|
||||
|
||||
|
||||
# ── Middleware integration tests ──────────────────────────────────────────
|
||||
|
||||
|
||||
def _make_app():
|
||||
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
|
||||
from fastapi import FastAPI
|
||||
|
||||
app = FastAPI()
|
||||
app.add_middleware(AuthMiddleware)
|
||||
|
||||
@app.get("/health")
|
||||
async def health():
|
||||
return {"status": "ok"}
|
||||
|
||||
@app.get("/api/v1/auth/me")
|
||||
async def auth_me():
|
||||
return {"id": "1", "email": "test@test.com"}
|
||||
|
||||
@app.get("/api/v1/auth/setup-status")
|
||||
async def setup_status():
|
||||
return {"needs_setup": False}
|
||||
|
||||
@app.get("/api/models")
|
||||
async def models_get():
|
||||
return {"models": []}
|
||||
|
||||
@app.put("/api/mcp/config")
|
||||
async def mcp_put():
|
||||
return {"ok": True}
|
||||
|
||||
@app.delete("/api/threads/abc")
|
||||
async def thread_delete():
|
||||
return {"ok": True}
|
||||
|
||||
@app.patch("/api/threads/abc")
|
||||
async def thread_patch():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/threads/abc/runs/stream")
|
||||
async def stream():
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/api/future-endpoint")
|
||||
async def future():
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
return TestClient(_make_app())
|
||||
|
||||
|
||||
def test_public_path_no_cookie(client):
|
||||
res = client.get("/health")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_public_auth_path_no_cookie(client):
|
||||
"""Public auth endpoints (login/register) pass without cookie."""
|
||||
res = client.get("/api/v1/auth/setup-status")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_protected_auth_path_no_cookie(client):
|
||||
"""/auth/me requires cookie even though it's under /api/v1/auth/."""
|
||||
res = client.get("/api/v1/auth/me")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_path_no_cookie_returns_401(client):
|
||||
res = client.get("/api/models")
|
||||
assert res.status_code == 401
|
||||
body = res.json()
|
||||
assert body["detail"]["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_protected_path_with_cookie_passes(client):
|
||||
res = client.get("/api/models", cookies={"access_token": "some-token"})
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_protected_post_no_cookie_returns_401(client):
|
||||
res = client.post("/api/threads/abc/runs/stream")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
|
||||
|
||||
|
||||
def test_protected_put_no_cookie(client):
|
||||
res = client.put("/api/mcp/config")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_delete_no_cookie(client):
|
||||
res = client.delete("/api/threads/abc")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_protected_patch_no_cookie(client):
|
||||
res = client.patch("/api/threads/abc")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_put_with_cookie_passes(client):
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.put("/api/mcp/config")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
def test_delete_with_cookie_passes(client):
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.delete("/api/threads/abc")
|
||||
assert res.status_code == 200
|
||||
|
||||
|
||||
# ── Fail-closed: unknown future endpoints ─────────────────────────────────
|
||||
|
||||
|
||||
def test_unknown_endpoint_no_cookie_returns_401(client):
|
||||
"""Any new /api/* endpoint is blocked by default without cookie."""
|
||||
res = client.get("/api/future-endpoint")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_unknown_endpoint_with_cookie_passes(client):
|
||||
client.cookies.set("access_token", "tok")
|
||||
res = client.get("/api/future-endpoint")
|
||||
assert res.status_code == 200
|
||||
@@ -0,0 +1,675 @@
|
||||
"""Tests for auth type system hardening.
|
||||
|
||||
Covers structured error responses, typed decode_token callers,
|
||||
CSRF middleware path matching, config-driven cookie security,
|
||||
and unhappy paths / edge cases for all auth boundaries.
|
||||
"""
|
||||
|
||||
import os
|
||||
import secrets
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import patch
|
||||
|
||||
import jwt as pyjwt
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
from pydantic import ValidationError
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import decode_token
|
||||
from app.gateway.csrf_middleware import (
|
||||
CSRF_COOKIE_NAME,
|
||||
CSRF_HEADER_NAME,
|
||||
CSRFMiddleware,
|
||||
is_auth_endpoint,
|
||||
should_check_csrf,
|
||||
)
|
||||
|
||||
# ── Setup ────────────────────────────────────────────────────────────
|
||||
|
||||
_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
|
||||
|
||||
|
||||
def _setup_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
|
||||
|
||||
# ── CSRF Middleware Path Matching ────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeRequest:
|
||||
"""Minimal request mock for CSRF path matching tests."""
|
||||
|
||||
def __init__(self, path: str, method: str = "POST"):
|
||||
self.method = method
|
||||
|
||||
class _URL:
|
||||
def __init__(self, p):
|
||||
self.path = p
|
||||
|
||||
self.url = _URL(path)
|
||||
self.cookies = {}
|
||||
self.headers = {}
|
||||
|
||||
|
||||
def test_csrf_exempts_login_local():
|
||||
"""login/local (actual route) should be exempt from CSRF."""
|
||||
req = _FakeRequest("/api/v1/auth/login/local")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_exempts_login_local_trailing_slash():
|
||||
"""Trailing slash should also be exempt."""
|
||||
req = _FakeRequest("/api/v1/auth/login/local/")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_exempts_logout():
|
||||
req = _FakeRequest("/api/v1/auth/logout")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_exempts_register():
|
||||
req = _FakeRequest("/api/v1/auth/register")
|
||||
assert is_auth_endpoint(req) is True
|
||||
|
||||
|
||||
def test_csrf_does_not_exempt_old_login_path():
|
||||
"""Old /api/v1/auth/login (without /local) should NOT be exempt."""
|
||||
req = _FakeRequest("/api/v1/auth/login")
|
||||
assert is_auth_endpoint(req) is False
|
||||
|
||||
|
||||
def test_csrf_does_not_exempt_me():
|
||||
req = _FakeRequest("/api/v1/auth/me")
|
||||
assert is_auth_endpoint(req) is False
|
||||
|
||||
|
||||
def test_csrf_skips_get_requests():
|
||||
req = _FakeRequest("/api/v1/auth/me", method="GET")
|
||||
assert should_check_csrf(req) is False
|
||||
|
||||
|
||||
def test_csrf_checks_post_to_protected():
|
||||
req = _FakeRequest("/api/v1/some/endpoint", method="POST")
|
||||
assert should_check_csrf(req) is True
|
||||
|
||||
|
||||
# ── Structured Error Response Format ────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_error_response_has_code_and_message():
|
||||
"""All auth errors should have structured {code, message} format."""
|
||||
err = AuthErrorResponse(
|
||||
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||
message="Wrong password",
|
||||
)
|
||||
d = err.model_dump()
|
||||
assert "code" in d
|
||||
assert "message" in d
|
||||
assert d["code"] == "invalid_credentials"
|
||||
|
||||
|
||||
def test_auth_error_response_all_codes_serializable():
|
||||
"""Every AuthErrorCode should be serializable in AuthErrorResponse."""
|
||||
for code in AuthErrorCode:
|
||||
err = AuthErrorResponse(code=code, message=f"Test {code.value}")
|
||||
d = err.model_dump()
|
||||
assert d["code"] == code.value
|
||||
|
||||
|
||||
# ── decode_token Caller Pattern ──────────────────────────────────────
|
||||
|
||||
|
||||
def test_decode_token_expired_maps_to_token_expired_code():
|
||||
"""TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED."""
|
||||
_setup_config()
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.EXPIRED
|
||||
|
||||
# Verify the mapping pattern used in route handlers
|
||||
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||
assert code == AuthErrorCode.TOKEN_EXPIRED
|
||||
|
||||
|
||||
def test_decode_token_invalid_sig_maps_to_token_invalid_code():
|
||||
"""TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID."""
|
||||
_setup_config()
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt as pyjwt
|
||||
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
|
||||
result = decode_token(token)
|
||||
assert result == TokenError.INVALID_SIGNATURE
|
||||
|
||||
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||
assert code == AuthErrorCode.TOKEN_INVALID
|
||||
|
||||
|
||||
def test_decode_token_malformed_maps_to_token_invalid_code():
|
||||
"""TokenError.MALFORMED should map to AuthErrorCode.TOKEN_INVALID."""
|
||||
_setup_config()
|
||||
result = decode_token("garbage")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||
assert code == AuthErrorCode.TOKEN_INVALID
|
||||
|
||||
|
||||
# ── Login Response Format ────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_login_response_model_has_no_access_token():
|
||||
"""LoginResponse should NOT contain access_token field (RFC-001)."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
resp = LoginResponse(expires_in=604800)
|
||||
d = resp.model_dump()
|
||||
assert "access_token" not in d
|
||||
assert "expires_in" in d
|
||||
assert d["expires_in"] == 604800
|
||||
|
||||
|
||||
def test_login_response_model_fields():
|
||||
"""LoginResponse has expires_in and needs_setup."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
fields = set(LoginResponse.model_fields.keys())
|
||||
assert fields == {"expires_in", "needs_setup"}
|
||||
|
||||
|
||||
# ── AuthConfig in Route ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_used_in_login_response():
|
||||
"""LoginResponse.expires_in should come from config.token_expiry_days."""
|
||||
from app.gateway.routers.auth import LoginResponse
|
||||
|
||||
expected_seconds = 14 * 24 * 3600
|
||||
resp = LoginResponse(expires_in=expected_seconds)
|
||||
assert resp.expires_in == expected_seconds
|
||||
|
||||
|
||||
# ── UserResponse Type Preservation ───────────────────────────────────
|
||||
|
||||
|
||||
def test_user_response_system_role_literal():
|
||||
"""UserResponse.system_role should only accept 'admin' or 'user'."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
# Valid roles
|
||||
resp = UserResponse(id="1", email="a@b.com", system_role="admin")
|
||||
assert resp.system_role == "admin"
|
||||
|
||||
resp = UserResponse(id="1", email="a@b.com", system_role="user")
|
||||
assert resp.system_role == "user"
|
||||
|
||||
|
||||
def test_user_response_rejects_invalid_role():
|
||||
"""UserResponse should reject invalid system_role values."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com", system_role="superadmin")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
# UNHAPPY PATHS / EDGE CASES
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
# ── get_current_user structured 401 responses ────────────────────────
|
||||
|
||||
|
||||
def test_get_current_user_no_cookie_returns_not_authenticated():
|
||||
"""No cookie → 401 with code=not_authenticated."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "not_authenticated"
|
||||
|
||||
|
||||
def test_get_current_user_expired_token_returns_token_expired():
|
||||
"""Expired token → 401 with code=token_expired."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "token_expired"
|
||||
|
||||
|
||||
def test_get_current_user_invalid_token_returns_token_invalid():
|
||||
"""Bad signature → 401 with code=token_invalid."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "token_invalid"
|
||||
|
||||
|
||||
def test_get_current_user_malformed_token_returns_token_invalid():
|
||||
"""Garbage token → 401 with code=token_invalid."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from app.gateway.deps import get_current_user_from_request
|
||||
|
||||
_setup_config()
|
||||
mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
asyncio.run(get_current_user_from_request(mock_request))
|
||||
assert exc_info.value.status_code == 401
|
||||
detail = exc_info.value.detail
|
||||
assert detail["code"] == "token_invalid"
|
||||
|
||||
|
||||
# ── decode_token edge cases ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_decode_token_empty_string_returns_malformed():
|
||||
_setup_config()
|
||||
result = decode_token("")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
|
||||
def test_decode_token_whitespace_returns_malformed():
|
||||
_setup_config()
|
||||
result = decode_token(" ")
|
||||
assert result == TokenError.MALFORMED
|
||||
|
||||
|
||||
# ── AuthConfig validation edge cases ─────────────────────────────────
|
||||
|
||||
|
||||
def test_auth_config_missing_jwt_secret_raises():
|
||||
"""AuthConfig requires jwt_secret — no default allowed."""
|
||||
with pytest.raises(ValidationError):
|
||||
AuthConfig()
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_zero_raises():
|
||||
"""token_expiry_days must be >= 1."""
|
||||
with pytest.raises(ValidationError):
|
||||
AuthConfig(jwt_secret="secret", token_expiry_days=0)
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_31_raises():
|
||||
"""token_expiry_days must be <= 30."""
|
||||
with pytest.raises(ValidationError):
|
||||
AuthConfig(jwt_secret="secret", token_expiry_days=31)
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_boundary_1_ok():
|
||||
config = AuthConfig(jwt_secret="secret", token_expiry_days=1)
|
||||
assert config.token_expiry_days == 1
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_boundary_30_ok():
|
||||
config = AuthConfig(jwt_secret="secret", token_expiry_days=30)
|
||||
assert config.token_expiry_days == 30
|
||||
|
||||
|
||||
def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
|
||||
"""get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
|
||||
import logging
|
||||
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||
with caplog.at_level(logging.WARNING):
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
|
||||
|
||||
# ── CSRF middleware integration (unhappy paths) ──────────────────────
|
||||
|
||||
|
||||
def _make_csrf_app():
|
||||
"""Create a minimal FastAPI app with CSRFMiddleware for testing."""
|
||||
from fastapi import HTTPException as _HTTPException
|
||||
from fastapi.responses import JSONResponse as _JSONResponse
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.exception_handler(_HTTPException)
|
||||
async def _http_exc_handler(request, exc):
|
||||
return _JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||
|
||||
app.add_middleware(CSRFMiddleware)
|
||||
|
||||
@app.post("/api/v1/test/protected")
|
||||
async def protected():
|
||||
return {"ok": True}
|
||||
|
||||
@app.post("/api/v1/auth/login/local")
|
||||
async def login():
|
||||
return {"ok": True}
|
||||
|
||||
@app.get("/api/v1/test/read")
|
||||
async def read_endpoint():
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def test_csrf_middleware_blocks_post_without_token():
|
||||
"""POST to protected endpoint without CSRF token → 403 with structured detail."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.post("/api/v1/test/protected")
|
||||
assert resp.status_code == 403
|
||||
assert "CSRF" in resp.json()["detail"]
|
||||
assert "missing" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_csrf_middleware_blocks_post_with_mismatched_token():
|
||||
"""POST with mismatched CSRF cookie/header → 403 with mismatch detail."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
client.cookies.set(CSRF_COOKIE_NAME, "token-a")
|
||||
resp = client.post(
|
||||
"/api/v1/test/protected",
|
||||
headers={CSRF_HEADER_NAME: "token-b"},
|
||||
)
|
||||
assert resp.status_code == 403
|
||||
assert "mismatch" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_csrf_middleware_allows_post_with_matching_token():
|
||||
"""POST with matching CSRF cookie/header → 200."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
token = secrets.token_urlsafe(64)
|
||||
client.cookies.set(CSRF_COOKIE_NAME, token)
|
||||
resp = client.post(
|
||||
"/api/v1/test/protected",
|
||||
headers={CSRF_HEADER_NAME: token},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_allows_get_without_token():
|
||||
"""GET requests bypass CSRF check."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.get("/api/v1/test/read")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_exempts_login_local():
|
||||
"""POST to login/local is exempt from CSRF (no token yet)."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.post("/api/v1/auth/login/local")
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
def test_csrf_middleware_sets_cookie_on_auth_endpoint():
|
||||
"""Auth endpoints should receive a CSRF cookie in response."""
|
||||
client = TestClient(_make_csrf_app())
|
||||
resp = client.post("/api/v1/auth/login/local")
|
||||
assert CSRF_COOKIE_NAME in resp.cookies
|
||||
|
||||
|
||||
# ── UserResponse edge cases ──────────────────────────────────────────
|
||||
|
||||
|
||||
def test_user_response_missing_required_fields():
|
||||
"""UserResponse with missing fields → ValidationError."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1") # missing email, system_role
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com") # missing system_role
|
||||
|
||||
|
||||
def test_user_response_empty_string_role_rejected():
|
||||
"""Empty string is not a valid role."""
|
||||
from app.gateway.auth.models import UserResponse
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
UserResponse(id="1", email="a@b.com", system_role="")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
# HTTP-LEVEL API CONTRACT TESTS
|
||||
# ══════════════════════════════════════════════════════════════════════
|
||||
|
||||
|
||||
def _make_auth_app():
|
||||
"""Create FastAPI app with auth routes for contract testing."""
|
||||
from app.gateway.app import create_app
|
||||
|
||||
return create_app()
|
||||
|
||||
|
||||
def _get_auth_client():
|
||||
"""Get TestClient for auth API contract tests."""
|
||||
return TestClient(_make_auth_app())
|
||||
|
||||
|
||||
def test_api_auth_me_no_cookie_returns_structured_401():
|
||||
"""/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "not_authenticated"
|
||||
assert "message" in body["detail"]
|
||||
|
||||
|
||||
def test_api_auth_me_expired_token_returns_structured_401():
|
||||
"""/api/v1/auth/me with expired token → 401 with {code: 'token_expired'}."""
|
||||
_setup_config()
|
||||
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||
|
||||
client = _get_auth_client()
|
||||
client.cookies.set("access_token", token)
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "token_expired"
|
||||
|
||||
|
||||
def test_api_auth_me_invalid_sig_returns_structured_401():
|
||||
"""/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
|
||||
_setup_config()
|
||||
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
|
||||
|
||||
client = _get_auth_client()
|
||||
client.cookies.set("access_token", token)
|
||||
resp = client.get("/api/v1/auth/me")
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "token_invalid"
|
||||
|
||||
|
||||
def test_api_login_bad_credentials_returns_structured_401():
|
||||
"""Login with wrong password → 401 with {code: 'invalid_credentials'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
|
||||
)
|
||||
assert resp.status_code == 401
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "invalid_credentials"
|
||||
|
||||
|
||||
def test_api_login_success_no_token_in_body():
|
||||
"""Successful login → response body has expires_in but NOT access_token."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
# Register first
|
||||
client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": "contract-test@test.com", "password": "securepassword123"},
|
||||
)
|
||||
# Login
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": "contract-test@test.com", "password": "securepassword123"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert "expires_in" in body
|
||||
assert "access_token" not in body
|
||||
# Token should be in cookie, not body
|
||||
assert "access_token" in resp.cookies
|
||||
|
||||
|
||||
def test_api_register_duplicate_returns_structured_400():
|
||||
"""Register with duplicate email → 400 with {code: 'email_already_exists'}."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
email = "dup-contract-test@test.com"
|
||||
# First register
|
||||
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
|
||||
# Duplicate
|
||||
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "password456"})
|
||||
assert resp.status_code == 400
|
||||
body = resp.json()
|
||||
assert body["detail"]["code"] == "email_already_exists"
|
||||
|
||||
|
||||
# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────
|
||||
|
||||
|
||||
def _unique_email(prefix: str) -> str:
|
||||
return f"{prefix}-{secrets.token_hex(4)}@test.com"
|
||||
|
||||
|
||||
def _get_set_cookie_headers(resp) -> list[str]:
|
||||
"""Extract all set-cookie header values from a TestClient response."""
|
||||
return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"]
|
||||
|
||||
|
||||
def test_register_http_cookie_httponly_true_secure_false():
|
||||
"""HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("http-cookie"), "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" not in cookie_header.lower().replace("samesite", "")
|
||||
|
||||
|
||||
def test_register_https_cookie_httponly_true_secure_true():
|
||||
"""HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("https-cookie"), "password": "password123"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" in cookie_header.lower()
|
||||
assert "max-age" in cookie_header.lower()
|
||||
|
||||
|
||||
def test_login_https_sets_secure_cookie():
|
||||
"""HTTPS login → access_token cookie has secure flag."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
email = _unique_email("https-login")
|
||||
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
|
||||
resp = client.post(
|
||||
"/api/v1/auth/login/local",
|
||||
data={"username": email, "password": "password123"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
cookie_header = resp.headers.get("set-cookie", "")
|
||||
assert "access_token=" in cookie_header
|
||||
assert "httponly" in cookie_header.lower()
|
||||
assert "secure" in cookie_header.lower()
|
||||
|
||||
|
||||
def test_csrf_cookie_secure_on_https():
|
||||
"""HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("csrf-https"), "password": "password123"},
|
||||
headers={"x-forwarded-proto": "https"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
|
||||
csrf_header = csrf_cookies[0]
|
||||
assert "secure" in csrf_header.lower()
|
||||
assert "httponly" not in csrf_header.lower()
|
||||
|
||||
|
||||
def test_csrf_cookie_not_secure_on_http():
|
||||
"""HTTP register → csrf_token cookie does NOT have secure flag."""
|
||||
_setup_config()
|
||||
client = _get_auth_client()
|
||||
resp = client.post(
|
||||
"/api/v1/auth/register",
|
||||
json={"email": _unique_email("csrf-http"), "password": "password123"},
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
|
||||
csrf_header = csrf_cookies[0]
|
||||
assert "secure" not in csrf_header.lower().replace("samesite", "")
|
||||
@@ -7,12 +7,12 @@ import json
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.store import ChannelStore
|
||||
|
||||
|
||||
@@ -1718,6 +1718,159 @@ class TestFeishuChannel:
|
||||
_run(go())
|
||||
|
||||
|
||||
class TestWeComChannel:
|
||||
def test_publish_ws_inbound_starts_stream_and_publishes_message(self, monkeypatch):
|
||||
from app.channels.wecom import WeComChannel
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
bus.publish_inbound = AsyncMock()
|
||||
channel = WeComChannel(bus, config={})
|
||||
channel._ws_client = SimpleNamespace(reply_stream=AsyncMock())
|
||||
|
||||
monkeypatch.setitem(
|
||||
__import__("sys").modules,
|
||||
"aibot",
|
||||
SimpleNamespace(generate_req_id=lambda prefix: "stream-1"),
|
||||
)
|
||||
|
||||
frame = {
|
||||
"body": {
|
||||
"msgid": "msg-1",
|
||||
"from": {"userid": "user-1"},
|
||||
"aibotid": "bot-1",
|
||||
"chattype": "single",
|
||||
}
|
||||
}
|
||||
files = [{"type": "image", "url": "https://example.com/image.png"}]
|
||||
|
||||
await channel._publish_ws_inbound(frame, "hello", files=files)
|
||||
|
||||
channel._ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "Working on it...", False)
|
||||
bus.publish_inbound.assert_awaited_once()
|
||||
|
||||
inbound = bus.publish_inbound.await_args.args[0]
|
||||
assert inbound.channel_name == "wecom"
|
||||
assert inbound.chat_id == "user-1"
|
||||
assert inbound.user_id == "user-1"
|
||||
assert inbound.text == "hello"
|
||||
assert inbound.thread_ts == "msg-1"
|
||||
assert inbound.topic_id == "user-1"
|
||||
assert inbound.files == files
|
||||
assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single"}
|
||||
assert channel._ws_frames["msg-1"] is frame
|
||||
assert channel._ws_stream_ids["msg-1"] == "stream-1"
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_publish_ws_inbound_uses_configured_working_message(self, monkeypatch):
|
||||
from app.channels.wecom import WeComChannel
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
bus.publish_inbound = AsyncMock()
|
||||
channel = WeComChannel(bus, config={"working_message": "Please wait..."})
|
||||
channel._ws_client = SimpleNamespace(reply_stream=AsyncMock())
|
||||
channel._working_message = "Please wait..."
|
||||
|
||||
monkeypatch.setitem(
|
||||
__import__("sys").modules,
|
||||
"aibot",
|
||||
SimpleNamespace(generate_req_id=lambda prefix: "stream-1"),
|
||||
)
|
||||
|
||||
frame = {
|
||||
"body": {
|
||||
"msgid": "msg-1",
|
||||
"from": {"userid": "user-1"},
|
||||
}
|
||||
}
|
||||
|
||||
await channel._publish_ws_inbound(frame, "hello")
|
||||
|
||||
channel._ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "Please wait...", False)
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_on_outbound_sends_attachment_before_clearing_context(self, tmp_path):
|
||||
from app.channels.wecom import WeComChannel
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
channel = WeComChannel(bus, config={})
|
||||
|
||||
frame = {"body": {"msgid": "msg-1"}}
|
||||
ws_client = SimpleNamespace(
|
||||
reply_stream=AsyncMock(),
|
||||
reply=AsyncMock(),
|
||||
)
|
||||
channel._ws_client = ws_client
|
||||
channel._ws_frames["msg-1"] = frame
|
||||
channel._ws_stream_ids["msg-1"] = "stream-1"
|
||||
channel._upload_media_ws = AsyncMock(return_value="media-1")
|
||||
|
||||
attachment_path = tmp_path / "image.png"
|
||||
attachment_path.write_bytes(b"png")
|
||||
attachment = ResolvedAttachment(
|
||||
virtual_path="/mnt/user-data/outputs/image.png",
|
||||
actual_path=attachment_path,
|
||||
filename="image.png",
|
||||
mime_type="image/png",
|
||||
size=attachment_path.stat().st_size,
|
||||
is_image=True,
|
||||
)
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel_name="wecom",
|
||||
chat_id="user-1",
|
||||
thread_id="thread-1",
|
||||
text="done",
|
||||
attachments=[attachment],
|
||||
is_final=True,
|
||||
thread_ts="msg-1",
|
||||
)
|
||||
|
||||
await channel._on_outbound(msg)
|
||||
|
||||
ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "done", True)
|
||||
channel._upload_media_ws.assert_awaited_once_with(
|
||||
media_type="image",
|
||||
filename="image.png",
|
||||
path=str(attachment_path),
|
||||
size=attachment.size,
|
||||
)
|
||||
ws_client.reply.assert_awaited_once_with(frame, {"image": {"media_id": "media-1"}, "msgtype": "image"})
|
||||
assert "msg-1" not in channel._ws_frames
|
||||
assert "msg-1" not in channel._ws_stream_ids
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_send_falls_back_to_send_message_without_thread_context(self):
|
||||
from app.channels.wecom import WeComChannel
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
channel = WeComChannel(bus, config={})
|
||||
channel._ws_client = SimpleNamespace(send_message=AsyncMock())
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel_name="wecom",
|
||||
chat_id="user-1",
|
||||
thread_id="thread-1",
|
||||
text="hello",
|
||||
thread_ts=None,
|
||||
)
|
||||
|
||||
await channel.send(msg)
|
||||
|
||||
channel._ws_client.send_message.assert_awaited_once_with(
|
||||
"user-1",
|
||||
{"msgtype": "markdown", "markdown": {"content": "hello"}},
|
||||
)
|
||||
|
||||
_run(go())
|
||||
|
||||
|
||||
class TestChannelService:
|
||||
def test_get_status_no_channels(self):
|
||||
from app.channels.service import ChannelService
|
||||
@@ -1835,6 +1988,47 @@ class TestSlackSendRetry:
|
||||
|
||||
_run(go())
|
||||
|
||||
|
||||
class TestSlackAllowedUsers:
|
||||
def test_numeric_allowed_users_match_string_event_user_id(self):
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
bus = MessageBus()
|
||||
bus.publish_inbound = AsyncMock()
|
||||
channel = SlackChannel(
|
||||
bus=bus,
|
||||
config={"allowed_users": [123456]},
|
||||
)
|
||||
channel._loop = MagicMock()
|
||||
channel._loop.is_running.return_value = True
|
||||
channel._add_reaction = MagicMock()
|
||||
channel._send_running_reply = MagicMock()
|
||||
|
||||
event = {
|
||||
"user": "123456",
|
||||
"text": "hello from slack",
|
||||
"channel": "C123",
|
||||
"ts": "1710000000.000100",
|
||||
}
|
||||
|
||||
def submit_coro(coro, loop):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
with patch(
|
||||
"app.channels.slack.asyncio.run_coroutine_threadsafe",
|
||||
side_effect=submit_coro,
|
||||
) as submit:
|
||||
channel._handle_message_event(event)
|
||||
|
||||
channel._add_reaction.assert_called_once_with("C123", "1710000000.000100", "eyes")
|
||||
channel._send_running_reply.assert_called_once_with("C123", "1710000000.000100")
|
||||
submit.assert_called_once()
|
||||
inbound = bus.publish_inbound.call_args.args[0]
|
||||
assert inbound.user_id == "123456"
|
||||
assert inbound.chat_id == "C123"
|
||||
assert inbound.text == "hello from slack"
|
||||
|
||||
def test_raises_after_all_retries_exhausted(self):
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
@@ -1854,6 +2048,20 @@ class TestSlackSendRetry:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_raises_runtime_error_when_no_attempts_configured(self):
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
ch = SlackChannel(bus=bus, config={"bot_token": "xoxb-test", "app_token": "xapp-test"})
|
||||
ch._web_client = MagicMock()
|
||||
|
||||
msg = OutboundMessage(channel_name="slack", chat_id="C123", thread_id="t1", text="hello")
|
||||
with pytest.raises(RuntimeError, match="without an exception"):
|
||||
await ch.send(msg, _max_retries=0)
|
||||
|
||||
_run(go())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Telegram send retry tests
|
||||
@@ -1912,6 +2120,36 @@ class TestTelegramSendRetry:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_raises_runtime_error_when_no_attempts_configured(self):
|
||||
from app.channels.telegram import TelegramChannel
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
ch = TelegramChannel(bus=bus, config={"bot_token": "test-token"})
|
||||
ch._application = MagicMock()
|
||||
|
||||
msg = OutboundMessage(channel_name="telegram", chat_id="12345", thread_id="t1", text="hello")
|
||||
with pytest.raises(RuntimeError, match="without an exception"):
|
||||
await ch.send(msg, _max_retries=0)
|
||||
|
||||
_run(go())
|
||||
|
||||
|
||||
class TestFeishuSendRetry:
|
||||
def test_raises_runtime_error_when_no_attempts_configured(self):
|
||||
from app.channels.feishu import FeishuChannel
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
ch = FeishuChannel(bus=bus, config={"app_id": "id", "app_secret": "secret"})
|
||||
ch._api_client = MagicMock()
|
||||
|
||||
msg = OutboundMessage(channel_name="feishu", chat_id="chat", thread_id="t1", text="hello")
|
||||
with pytest.raises(RuntimeError, match="without an exception"):
|
||||
await ch.send(msg, _max_retries=0)
|
||||
|
||||
_run(go())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Telegram private-chat thread context tests
|
||||
|
||||
@@ -59,18 +59,20 @@ class TestClientInit:
|
||||
assert client._subagent_enabled is False
|
||||
assert client._plan_mode is False
|
||||
assert client._agent_name is None
|
||||
assert client._available_skills is None
|
||||
assert client._checkpointer is None
|
||||
assert client._agent is None
|
||||
|
||||
def test_custom_params(self, mock_app_config):
|
||||
mock_middleware = MagicMock()
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", middlewares=[mock_middleware])
|
||||
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
|
||||
assert c._model_name == "gpt-4"
|
||||
assert c._thinking_enabled is False
|
||||
assert c._subagent_enabled is True
|
||||
assert c._plan_mode is True
|
||||
assert c._agent_name == "test-agent"
|
||||
assert c._available_skills == {"skill1", "skill2"}
|
||||
assert c._middlewares == [mock_middleware]
|
||||
|
||||
def test_invalid_agent_name(self, mock_app_config):
|
||||
@@ -394,8 +396,10 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._agent_name = "custom-agent"
|
||||
client._available_skills = {"test_skill"}
|
||||
client._ensure_agent(config)
|
||||
|
||||
assert client._agent is mock_agent
|
||||
@@ -404,6 +408,7 @@ class TestEnsureAgent:
|
||||
assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent"
|
||||
mock_apply_prompt.assert_called_once()
|
||||
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
|
||||
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
|
||||
|
||||
def test_uses_default_checkpointer_when_available(self, client):
|
||||
mock_agent = MagicMock()
|
||||
@@ -441,6 +446,7 @@ class TestEnsureAgent:
|
||||
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
@@ -469,7 +475,7 @@ class TestEnsureAgent:
|
||||
"""_ensure_agent does not recreate if config key unchanged."""
|
||||
mock_agent = MagicMock()
|
||||
client._agent = mock_agent
|
||||
client._agent_config_key = (None, True, False, False)
|
||||
client._agent_config_key = (None, True, False, False, None, None)
|
||||
|
||||
config = client._get_runnable_config("t1")
|
||||
client._ensure_agent(config)
|
||||
@@ -1276,6 +1282,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config_a)
|
||||
first_agent = client._agent
|
||||
@@ -1303,6 +1310,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
client._ensure_agent(config)
|
||||
@@ -1327,6 +1335,7 @@ class TestScenarioAgentRecreation:
|
||||
patch("deerflow.client._build_middlewares", return_value=[]),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
client.reset_agent()
|
||||
|
||||
@@ -439,6 +439,15 @@ class TestAgentsAPI:
|
||||
assert "agent-one" in names
|
||||
assert "agent-two" in names
|
||||
|
||||
def test_list_agents_includes_soul(self, agent_client):
|
||||
agent_client.post("/api/agents", json={"name": "soul-agent", "soul": "My soul content"})
|
||||
|
||||
response = agent_client.get("/api/agents")
|
||||
assert response.status_code == 200
|
||||
agents = response.json()["agents"]
|
||||
soul_agent = next(a for a in agents if a["name"] == "soul-agent")
|
||||
assert soul_agent["soul"] == "My soul content"
|
||||
|
||||
def test_get_agent(self, agent_client):
|
||||
agent_client.post("/api/agents", json={"name": "test-agent", "soul": "Hello world"})
|
||||
|
||||
|
||||
@@ -0,0 +1,214 @@
|
||||
"""Tests for _ensure_admin_user() in app.py.
|
||||
|
||||
Covers: first-boot admin creation, auto-reset on needs_setup=True,
|
||||
no-op on needs_setup=False, migration, and edge cases.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.models import User
|
||||
|
||||
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_auth_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||
yield
|
||||
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||
|
||||
|
||||
def _make_app_stub(store=None):
|
||||
"""Minimal app-like object with state.store."""
|
||||
app = SimpleNamespace()
|
||||
app.state = SimpleNamespace()
|
||||
app.state.store = store
|
||||
return app
|
||||
|
||||
|
||||
def _make_provider(user_count=0, admin_user=None):
|
||||
p = AsyncMock()
|
||||
p.count_users = AsyncMock(return_value=user_count)
|
||||
p.create_user = AsyncMock(
|
||||
side_effect=lambda **kw: User(
|
||||
email=kw["email"],
|
||||
password_hash="hashed",
|
||||
system_role=kw.get("system_role", "user"),
|
||||
needs_setup=kw.get("needs_setup", False),
|
||||
)
|
||||
)
|
||||
p.get_user_by_email = AsyncMock(return_value=admin_user)
|
||||
p.update_user = AsyncMock(side_effect=lambda u: u)
|
||||
return p
|
||||
|
||||
|
||||
# ── First boot: no users ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_first_boot_creates_admin():
|
||||
"""count_users==0 → create admin with needs_setup=True."""
|
||||
provider = _make_provider(user_count=0)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.create_user.assert_called_once()
|
||||
call_kwargs = provider.create_user.call_args[1]
|
||||
assert call_kwargs["email"] == "admin@deerflow.dev"
|
||||
assert call_kwargs["system_role"] == "admin"
|
||||
assert call_kwargs["needs_setup"] is True
|
||||
assert len(call_kwargs["password"]) > 10 # random password generated
|
||||
|
||||
|
||||
def test_first_boot_triggers_migration_if_store_present():
|
||||
"""First boot with store → _migrate_orphaned_threads called."""
|
||||
provider = _make_provider(user_count=0)
|
||||
store = AsyncMock()
|
||||
store.asearch = AsyncMock(return_value=[])
|
||||
app = _make_app_stub(store=store)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
store.asearch.assert_called_once()
|
||||
|
||||
|
||||
def test_first_boot_no_store_skips_migration():
|
||||
"""First boot without store → no crash, migration skipped."""
|
||||
provider = _make_provider(user_count=0)
|
||||
app = _make_app_stub(store=None)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.create_user.assert_called_once()
|
||||
|
||||
|
||||
# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────
|
||||
|
||||
|
||||
def test_needs_setup_true_resets_password():
|
||||
"""Existing admin with needs_setup=True → password reset + token_version bumped."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="old-hash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=0,
|
||||
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
# Password was reset
|
||||
provider.update_user.assert_called_once()
|
||||
updated = provider.update_user.call_args[0][0]
|
||||
assert updated.password_hash == "new-hash"
|
||||
assert updated.token_version == 1
|
||||
|
||||
|
||||
def test_needs_setup_true_consecutive_resets_increment_version():
|
||||
"""Two boots with needs_setup=True → token_version increments each time."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="hash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
updated = provider.update_user.call_args[0][0]
|
||||
assert updated.token_version == 4
|
||||
|
||||
|
||||
# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────
|
||||
|
||||
|
||||
def test_needs_setup_false_no_reset():
|
||||
"""Admin with needs_setup=False → no password reset, no update."""
|
||||
admin = User(
|
||||
email="admin@deerflow.dev",
|
||||
password_hash="stable-hash",
|
||||
system_role="admin",
|
||||
needs_setup=False,
|
||||
token_version=2,
|
||||
)
|
||||
provider = _make_provider(user_count=1, admin_user=admin)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.update_user.assert_not_called()
|
||||
assert admin.password_hash == "stable-hash"
|
||||
assert admin.token_version == 2
|
||||
|
||||
|
||||
# ── Edge cases ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_no_admin_email_found_no_crash():
|
||||
"""Users exist but no admin@deerflow.dev → no crash, no reset."""
|
||||
provider = _make_provider(user_count=3, admin_user=None)
|
||||
app = _make_app_stub()
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.update_user.assert_not_called()
|
||||
provider.create_user.assert_not_called()
|
||||
|
||||
|
||||
def test_migration_failure_is_non_fatal():
|
||||
"""_migrate_orphaned_threads exception is caught and logged."""
|
||||
provider = _make_provider(user_count=0)
|
||||
store = AsyncMock()
|
||||
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
|
||||
app = _make_app_stub(store=store)
|
||||
|
||||
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||
from app.gateway.app import _ensure_admin_user
|
||||
|
||||
# Should not raise
|
||||
asyncio.run(_ensure_admin_user(app))
|
||||
|
||||
provider.create_user.assert_called_once()
|
||||
@@ -0,0 +1,459 @@
|
||||
"""Tests for file_conversion utilities (PR1: pymupdf4llm + asyncio.to_thread; PR2: extract_outline)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from types import ModuleType
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from deerflow.utils.file_conversion import (
|
||||
_ASYNC_THRESHOLD_BYTES,
|
||||
_MIN_CHARS_PER_PAGE,
|
||||
MAX_OUTLINE_ENTRIES,
|
||||
_do_convert,
|
||||
_pymupdf_output_too_sparse,
|
||||
convert_file_to_markdown,
|
||||
extract_outline,
|
||||
)
|
||||
|
||||
|
||||
def _make_pymupdf_mock(page_count: int) -> ModuleType:
|
||||
"""Return a fake *pymupdf* module whose ``open()`` reports *page_count* pages."""
|
||||
mock_doc = MagicMock()
|
||||
mock_doc.__len__ = MagicMock(return_value=page_count)
|
||||
fake_pymupdf = ModuleType("pymupdf")
|
||||
fake_pymupdf.open = MagicMock(return_value=mock_doc) # type: ignore[attr-defined]
|
||||
return fake_pymupdf
|
||||
|
||||
|
||||
def _run(coro):
|
||||
loop = asyncio.new_event_loop()
|
||||
try:
|
||||
return loop.run_until_complete(coro)
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _pymupdf_output_too_sparse
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPymupdfOutputTooSparse:
|
||||
"""Check the chars-per-page sparsity heuristic."""
|
||||
|
||||
def test_dense_text_pdf_not_sparse(self, tmp_path):
|
||||
"""Normal text PDF: many chars per page → not sparse."""
|
||||
pdf = tmp_path / "dense.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
# 10 pages × 10 000 chars → 1000/page ≫ threshold
|
||||
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=10)}):
|
||||
result = _pymupdf_output_too_sparse("x" * 10_000, pdf)
|
||||
assert result is False
|
||||
|
||||
def test_image_based_pdf_is_sparse(self, tmp_path):
|
||||
"""Image-based PDF: near-zero chars per page → sparse."""
|
||||
pdf = tmp_path / "image.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
# 612 chars / 31 pages ≈ 19.7/page < _MIN_CHARS_PER_PAGE (50)
|
||||
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=31)}):
|
||||
result = _pymupdf_output_too_sparse("x" * 612, pdf)
|
||||
assert result is True
|
||||
|
||||
def test_fallback_when_pymupdf_unavailable(self, tmp_path):
|
||||
"""When pymupdf is not installed, fall back to absolute 200-char threshold."""
|
||||
pdf = tmp_path / "broken.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
# Remove pymupdf from sys.modules so the `import pymupdf` inside the
|
||||
# function raises ImportError, triggering the absolute-threshold fallback.
|
||||
with patch.dict(sys.modules, {"pymupdf": None}):
|
||||
sparse = _pymupdf_output_too_sparse("x" * 100, pdf)
|
||||
not_sparse = _pymupdf_output_too_sparse("x" * 300, pdf)
|
||||
|
||||
assert sparse is True
|
||||
assert not_sparse is False
|
||||
|
||||
def test_exactly_at_threshold_is_not_sparse(self, tmp_path):
|
||||
"""Chars-per-page == threshold is treated as NOT sparse (boundary inclusive)."""
|
||||
pdf = tmp_path / "boundary.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
# 2 pages × _MIN_CHARS_PER_PAGE chars = exactly at threshold
|
||||
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=2)}):
|
||||
result = _pymupdf_output_too_sparse("x" * (_MIN_CHARS_PER_PAGE * 2), pdf)
|
||||
assert result is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _do_convert — routing logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDoConvert:
|
||||
"""Verify that _do_convert routes to the right sub-converter."""
|
||||
|
||||
def test_non_pdf_always_uses_markitdown(self, tmp_path):
|
||||
"""DOCX / XLSX / PPTX always go through MarkItDown regardless of setting."""
|
||||
docx = tmp_path / "report.docx"
|
||||
docx.write_bytes(b"PK fake docx")
|
||||
|
||||
with patch(
|
||||
"deerflow.utils.file_conversion._convert_with_markitdown",
|
||||
return_value="# Markdown from MarkItDown",
|
||||
) as mock_md:
|
||||
result = _do_convert(docx, "auto")
|
||||
|
||||
mock_md.assert_called_once_with(docx)
|
||||
assert result == "# Markdown from MarkItDown"
|
||||
|
||||
def test_pdf_auto_uses_pymupdf4llm_when_dense(self, tmp_path):
|
||||
"""auto mode: use pymupdf4llm output when it's dense enough."""
|
||||
pdf = tmp_path / "report.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
dense_text = "# Heading\n" + "word " * 2000 # clearly dense
|
||||
|
||||
with (
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
|
||||
return_value=dense_text,
|
||||
),
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._pymupdf_output_too_sparse",
|
||||
return_value=False,
|
||||
),
|
||||
patch("deerflow.utils.file_conversion._convert_with_markitdown") as mock_md,
|
||||
):
|
||||
result = _do_convert(pdf, "auto")
|
||||
|
||||
mock_md.assert_not_called()
|
||||
assert result == dense_text
|
||||
|
||||
def test_pdf_auto_falls_back_when_sparse(self, tmp_path):
|
||||
"""auto mode: fall back to MarkItDown when pymupdf4llm output is sparse."""
|
||||
pdf = tmp_path / "scanned.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
|
||||
return_value="x" * 612, # 19.7 chars/page for 31-page doc
|
||||
),
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._pymupdf_output_too_sparse",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._convert_with_markitdown",
|
||||
return_value="OCR result via MarkItDown",
|
||||
) as mock_md,
|
||||
):
|
||||
result = _do_convert(pdf, "auto")
|
||||
|
||||
mock_md.assert_called_once_with(pdf)
|
||||
assert result == "OCR result via MarkItDown"
|
||||
|
||||
def test_pdf_explicit_pymupdf4llm_skips_sparsity_check(self, tmp_path):
|
||||
"""'pymupdf4llm' mode: use output as-is even if sparse."""
|
||||
pdf = tmp_path / "explicit.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
sparse_text = "x" * 10 # very short
|
||||
|
||||
with (
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
|
||||
return_value=sparse_text,
|
||||
),
|
||||
patch("deerflow.utils.file_conversion._convert_with_markitdown") as mock_md,
|
||||
):
|
||||
result = _do_convert(pdf, "pymupdf4llm")
|
||||
|
||||
mock_md.assert_not_called()
|
||||
assert result == sparse_text
|
||||
|
||||
def test_pdf_explicit_markitdown_skips_pymupdf4llm(self, tmp_path):
|
||||
"""'markitdown' mode: never attempt pymupdf4llm."""
|
||||
pdf = tmp_path / "force_md.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
with (
|
||||
patch("deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm") as mock_pymu,
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._convert_with_markitdown",
|
||||
return_value="MarkItDown result",
|
||||
),
|
||||
):
|
||||
result = _do_convert(pdf, "markitdown")
|
||||
|
||||
mock_pymu.assert_not_called()
|
||||
assert result == "MarkItDown result"
|
||||
|
||||
def test_pdf_auto_falls_back_when_pymupdf4llm_not_installed(self, tmp_path):
|
||||
"""auto mode: if pymupdf4llm is not installed, use MarkItDown directly."""
|
||||
pdf = tmp_path / "no_pymupdf.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
|
||||
return_value=None, # None signals not installed
|
||||
),
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._convert_with_markitdown",
|
||||
return_value="MarkItDown fallback",
|
||||
) as mock_md,
|
||||
):
|
||||
result = _do_convert(pdf, "auto")
|
||||
|
||||
mock_md.assert_called_once_with(pdf)
|
||||
assert result == "MarkItDown fallback"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# convert_file_to_markdown — async + file writing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConvertFileToMarkdown:
|
||||
def test_small_file_runs_synchronously(self, tmp_path):
|
||||
"""Small files (< 1 MB) are converted in the event loop thread."""
|
||||
pdf = tmp_path / "small.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 " + b"x" * 100) # well under 1 MB
|
||||
|
||||
with (
|
||||
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._do_convert",
|
||||
return_value="# Small PDF",
|
||||
) as mock_convert,
|
||||
patch("asyncio.to_thread") as mock_thread,
|
||||
):
|
||||
md_path = _run(convert_file_to_markdown(pdf))
|
||||
|
||||
# asyncio.to_thread must NOT have been called
|
||||
mock_thread.assert_not_called()
|
||||
mock_convert.assert_called_once()
|
||||
assert md_path == pdf.with_suffix(".md")
|
||||
assert md_path.read_text() == "# Small PDF"
|
||||
|
||||
def test_large_file_offloaded_to_thread(self, tmp_path):
|
||||
"""Large files (> 1 MB) are offloaded via asyncio.to_thread."""
|
||||
pdf = tmp_path / "large.pdf"
|
||||
# Write slightly more than the threshold
|
||||
pdf.write_bytes(b"%PDF-1.4 " + b"x" * (_ASYNC_THRESHOLD_BYTES + 1))
|
||||
|
||||
async def fake_to_thread(fn, *args, **kwargs):
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
with (
|
||||
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._do_convert",
|
||||
return_value="# Large PDF",
|
||||
),
|
||||
patch("asyncio.to_thread", side_effect=fake_to_thread) as mock_thread,
|
||||
):
|
||||
md_path = _run(convert_file_to_markdown(pdf))
|
||||
|
||||
mock_thread.assert_called_once()
|
||||
assert md_path == pdf.with_suffix(".md")
|
||||
assert md_path.read_text() == "# Large PDF"
|
||||
|
||||
def test_returns_none_on_conversion_error(self, tmp_path):
|
||||
"""If conversion raises, return None without propagating the exception."""
|
||||
pdf = tmp_path / "broken.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
|
||||
with (
|
||||
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._do_convert",
|
||||
side_effect=RuntimeError("conversion failed"),
|
||||
),
|
||||
):
|
||||
result = _run(convert_file_to_markdown(pdf))
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_writes_utf8_markdown_file(self, tmp_path):
|
||||
"""Generated .md file is written with UTF-8 encoding."""
|
||||
pdf = tmp_path / "report.pdf"
|
||||
pdf.write_bytes(b"%PDF-1.4 fake")
|
||||
chinese_content = "# 中文报告\n\n这是测试内容。"
|
||||
|
||||
with (
|
||||
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
|
||||
patch(
|
||||
"deerflow.utils.file_conversion._do_convert",
|
||||
return_value=chinese_content,
|
||||
),
|
||||
):
|
||||
md_path = _run(convert_file_to_markdown(pdf))
|
||||
|
||||
assert md_path is not None
|
||||
assert md_path.read_text(encoding="utf-8") == chinese_content
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# extract_outline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestExtractOutline:
|
||||
"""Tests for extract_outline()."""
|
||||
|
||||
def test_empty_file_returns_empty(self, tmp_path):
|
||||
"""Empty markdown file yields no outline entries."""
|
||||
md = tmp_path / "empty.md"
|
||||
md.write_text("", encoding="utf-8")
|
||||
assert extract_outline(md) == []
|
||||
|
||||
def test_missing_file_returns_empty(self, tmp_path):
|
||||
"""Non-existent path returns [] without raising."""
|
||||
assert extract_outline(tmp_path / "nonexistent.md") == []
|
||||
|
||||
def test_standard_markdown_headings(self, tmp_path):
|
||||
"""# / ## / ### headings are all recognised."""
|
||||
md = tmp_path / "doc.md"
|
||||
md.write_text(
|
||||
"# Chapter One\n\nSome text.\n\n## Section 1.1\n\nMore text.\n\n### Sub 1.1.1\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
outline = extract_outline(md)
|
||||
assert len(outline) == 3
|
||||
assert outline[0] == {"title": "Chapter One", "line": 1}
|
||||
assert outline[1] == {"title": "Section 1.1", "line": 5}
|
||||
assert outline[2] == {"title": "Sub 1.1.1", "line": 9}
|
||||
|
||||
def test_bold_sec_item_heading(self, tmp_path):
|
||||
"""**ITEM N. TITLE** lines in SEC filings are recognised."""
|
||||
md = tmp_path / "10k.md"
|
||||
md.write_text(
|
||||
"Cover page text.\n\n**ITEM 1. BUSINESS**\n\nBody.\n\n**ITEM 1A. RISK FACTORS**\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
outline = extract_outline(md)
|
||||
assert len(outline) == 2
|
||||
assert outline[0] == {"title": "ITEM 1. BUSINESS", "line": 3}
|
||||
assert outline[1] == {"title": "ITEM 1A. RISK FACTORS", "line": 7}
|
||||
|
||||
def test_bold_part_heading(self, tmp_path):
|
||||
"""**PART I** / **PART II** headings are recognised."""
|
||||
md = tmp_path / "10k.md"
|
||||
md.write_text("**PART I**\n\n**PART II**\n\n**PART III**\n", encoding="utf-8")
|
||||
outline = extract_outline(md)
|
||||
assert len(outline) == 3
|
||||
titles = [e["title"] for e in outline]
|
||||
assert "PART I" in titles
|
||||
assert "PART II" in titles
|
||||
assert "PART III" in titles
|
||||
|
||||
def test_sec_cover_page_boilerplate_excluded(self, tmp_path):
|
||||
"""Address lines and short cover boilerplate must NOT appear in outline."""
|
||||
md = tmp_path / "8k.md"
|
||||
md.write_text(
|
||||
"## **UNITED STATES SECURITIES AND EXCHANGE COMMISSION**\n\n**WASHINGTON, DC 20549**\n\n**CURRENT REPORT**\n\n**SIGNATURES**\n\n**TESLA, INC.**\n\n**ITEM 2.02. RESULTS OF OPERATIONS**\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
outline = extract_outline(md)
|
||||
titles = [e["title"] for e in outline]
|
||||
# Cover-page boilerplate should be excluded
|
||||
assert "WASHINGTON, DC 20549" not in titles
|
||||
assert "CURRENT REPORT" not in titles
|
||||
assert "SIGNATURES" not in titles
|
||||
assert "TESLA, INC." not in titles
|
||||
# Real SEC heading must be included
|
||||
assert "ITEM 2.02. RESULTS OF OPERATIONS" in titles
|
||||
|
||||
def test_chinese_headings_via_standard_markdown(self, tmp_path):
|
||||
"""Chinese annual report headings emitted as # by pymupdf4llm are captured."""
|
||||
md = tmp_path / "annual.md"
|
||||
md.write_text(
|
||||
"# 第一节 公司简介\n\n内容。\n\n## 第三节 管理层讨论与分析\n\n分析内容。\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
outline = extract_outline(md)
|
||||
assert len(outline) == 2
|
||||
assert outline[0]["title"] == "第一节 公司简介"
|
||||
assert outline[1]["title"] == "第三节 管理层讨论与分析"
|
||||
|
||||
def test_outline_capped_at_max_entries(self, tmp_path):
|
||||
"""When truncated, result has MAX_OUTLINE_ENTRIES real entries + 1 sentinel."""
|
||||
lines = [f"# Heading {i}" for i in range(MAX_OUTLINE_ENTRIES + 10)]
|
||||
md = tmp_path / "long.md"
|
||||
md.write_text("\n".join(lines), encoding="utf-8")
|
||||
outline = extract_outline(md)
|
||||
# Last entry is the truncation sentinel
|
||||
assert outline[-1] == {"truncated": True}
|
||||
# Visible entries are exactly MAX_OUTLINE_ENTRIES
|
||||
visible = [e for e in outline if not e.get("truncated")]
|
||||
assert len(visible) == MAX_OUTLINE_ENTRIES
|
||||
|
||||
def test_no_truncation_sentinel_when_under_limit(self, tmp_path):
|
||||
"""Short documents produce no sentinel entry."""
|
||||
lines = [f"# Heading {i}" for i in range(5)]
|
||||
md = tmp_path / "short.md"
|
||||
md.write_text("\n".join(lines), encoding="utf-8")
|
||||
outline = extract_outline(md)
|
||||
assert len(outline) == 5
|
||||
assert not any(e.get("truncated") for e in outline)
|
||||
|
||||
def test_blank_lines_and_whitespace_ignored(self, tmp_path):
|
||||
"""Blank lines between headings do not produce empty entries."""
|
||||
md = tmp_path / "spaced.md"
|
||||
md.write_text("\n\n# Title One\n\n\n\n# Title Two\n\n", encoding="utf-8")
|
||||
outline = extract_outline(md)
|
||||
assert len(outline) == 2
|
||||
assert all(e["title"] for e in outline)
|
||||
|
||||
def test_inline_bold_not_confused_with_heading(self, tmp_path):
|
||||
"""Mid-sentence bold text must not be mistaken for a heading."""
|
||||
md = tmp_path / "prose.md"
|
||||
md.write_text(
|
||||
"This sentence has **bold words** inside it.\n\nAnother with **MULTIPLE CAPS** inline.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
outline = extract_outline(md)
|
||||
assert outline == []
|
||||
|
||||
def test_split_bold_heading_academic_paper(self, tmp_path):
|
||||
"""**<num>** **<title>** lines from academic papers are recognised (Style 3)."""
|
||||
md = tmp_path / "paper.md"
|
||||
md.write_text(
|
||||
"## **Attention Is All You Need**\n\n**1** **Introduction**\n\nBody text.\n\n**2** **Background**\n\nMore text.\n\n**3.1** **Encoder and Decoder Stacks**\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
outline = extract_outline(md)
|
||||
titles = [e["title"] for e in outline]
|
||||
assert "1 Introduction" in titles
|
||||
assert "2 Background" in titles
|
||||
assert "3.1 Encoder and Decoder Stacks" in titles
|
||||
|
||||
def test_split_bold_year_columns_excluded(self, tmp_path):
|
||||
"""Financial table headers like **2023** **2022** **2021** are NOT headings."""
|
||||
md = tmp_path / "annual.md"
|
||||
md.write_text(
|
||||
"# Financial Summary\n\n**2023** **2022** **2021**\n\nRevenue 100 90 80\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
outline = extract_outline(md)
|
||||
titles = [e["title"] for e in outline]
|
||||
# Only the # heading should appear, not the year-column row
|
||||
assert titles == ["Financial Summary"]
|
||||
|
||||
def test_adjacent_bold_spans_merged_in_markdown_heading(self, tmp_path):
|
||||
"""** ** artefacts inside a # heading are merged into clean plain text."""
|
||||
md = tmp_path / "sec.md"
|
||||
md.write_text(
|
||||
"## **UNITED STATES** **SECURITIES AND EXCHANGE COMMISSION**\n\nBody text.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
outline = extract_outline(md)
|
||||
assert len(outline) == 1
|
||||
# Title must be clean — no ** ** artefacts
|
||||
assert outline[0]["title"] == "UNITED STATES SECURITIES AND EXCHANGE COMMISSION"
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
from deerflow.config.acp_config import ACPAgentConfig
|
||||
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig, set_extensions_config
|
||||
from deerflow.tools.builtins.invoke_acp_agent_tool import (
|
||||
_build_acp_mcp_servers,
|
||||
_build_mcp_servers,
|
||||
_build_permission_response,
|
||||
_get_work_dir,
|
||||
@@ -42,6 +43,43 @@ def test_build_mcp_servers_filters_disabled_and_maps_transports():
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
|
||||
|
||||
|
||||
def test_build_acp_mcp_servers_formats_list_payload():
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
|
||||
fresh_config = ExtensionsConfig(
|
||||
mcp_servers={
|
||||
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"], env={"FOO": "bar"}),
|
||||
"http": McpServerConfig(enabled=True, type="http", url="https://example.com/mcp", headers={"Authorization": "Bearer token"}),
|
||||
"disabled": McpServerConfig(enabled=False, type="stdio", command="echo"),
|
||||
},
|
||||
skills={},
|
||||
)
|
||||
monkeypatch = pytest.MonkeyPatch()
|
||||
monkeypatch.setattr(
|
||||
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
|
||||
classmethod(lambda cls: fresh_config),
|
||||
)
|
||||
|
||||
try:
|
||||
assert _build_acp_mcp_servers() == [
|
||||
{
|
||||
"name": "stdio",
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["srv"],
|
||||
"env": [{"name": "FOO", "value": "bar"}],
|
||||
},
|
||||
{
|
||||
"name": "http",
|
||||
"type": "http",
|
||||
"url": "https://example.com/mcp",
|
||||
"headers": [{"name": "Authorization", "value": "Bearer token"}],
|
||||
},
|
||||
]
|
||||
finally:
|
||||
monkeypatch.undo()
|
||||
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
|
||||
|
||||
|
||||
def test_build_permission_response_prefers_allow_once():
|
||||
response = _build_permission_response(
|
||||
[
|
||||
@@ -251,9 +289,15 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
|
||||
assert captured["spawn"] == {"cmd": "codex-acp", "args": ["--json"], "cwd": expected_cwd}
|
||||
assert captured["new_session"] == {
|
||||
"cwd": expected_cwd,
|
||||
"mcp_servers": {
|
||||
"github": {"transport": "stdio", "command": "npx", "args": ["github-mcp"]},
|
||||
},
|
||||
"mcp_servers": [
|
||||
{
|
||||
"name": "github",
|
||||
"type": "stdio",
|
||||
"command": "npx",
|
||||
"args": ["github-mcp"],
|
||||
"env": [],
|
||||
}
|
||||
],
|
||||
"model": "gpt-5-codex",
|
||||
}
|
||||
assert captured["prompt"] == {
|
||||
@@ -448,6 +492,94 @@ async def test_invoke_acp_agent_passes_env_to_spawn(monkeypatch, tmp_path):
|
||||
assert captured["env"] == {"OPENAI_API_KEY": "sk-from-env", "FOO": "bar"}
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invoke_acp_agent_skips_invalid_mcp_servers(monkeypatch, tmp_path, caplog):
|
||||
"""Invalid MCP config should be logged and skipped instead of failing ACP invocation."""
|
||||
from deerflow.config import paths as paths_module
|
||||
|
||||
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
|
||||
monkeypatch.setattr(
|
||||
"deerflow.tools.builtins.invoke_acp_agent_tool._build_acp_mcp_servers",
|
||||
lambda: (_ for _ in ()).throw(ValueError("missing command")),
|
||||
)
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
|
||||
class DummyClient:
|
||||
def __init__(self) -> None:
|
||||
self._chunks: list[str] = []
|
||||
|
||||
@property
|
||||
def collected_text(self) -> str:
|
||||
return ""
|
||||
|
||||
async def session_update(self, session_id, update, **kwargs):
|
||||
pass
|
||||
|
||||
async def request_permission(self, options, session_id, tool_call, **kwargs):
|
||||
raise AssertionError("should not be called")
|
||||
|
||||
class DummyConn:
|
||||
async def initialize(self, **kwargs):
|
||||
pass
|
||||
|
||||
async def new_session(self, **kwargs):
|
||||
captured["new_session"] = kwargs
|
||||
return SimpleNamespace(session_id="s1")
|
||||
|
||||
async def prompt(self, **kwargs):
|
||||
pass
|
||||
|
||||
class DummyProcessContext:
|
||||
def __init__(self, client, cmd, *args, env=None, cwd=None):
|
||||
captured["spawn"] = {"cmd": cmd, "args": list(args), "env": env, "cwd": cwd}
|
||||
|
||||
async def __aenter__(self):
|
||||
return DummyConn(), object()
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
class DummyRequestError(Exception):
|
||||
@staticmethod
|
||||
def method_not_found(method):
|
||||
return DummyRequestError(method)
|
||||
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"acp",
|
||||
SimpleNamespace(
|
||||
PROTOCOL_VERSION="2026-03-24",
|
||||
Client=DummyClient,
|
||||
RequestError=DummyRequestError,
|
||||
spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd),
|
||||
text_block=lambda text: {"type": "text", "text": text},
|
||||
),
|
||||
)
|
||||
monkeypatch.setitem(
|
||||
sys.modules,
|
||||
"acp.schema",
|
||||
SimpleNamespace(
|
||||
ClientCapabilities=lambda: {},
|
||||
Implementation=lambda **kwargs: kwargs,
|
||||
TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}),
|
||||
),
|
||||
)
|
||||
|
||||
tool = build_invoke_acp_agent_tool({"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")})
|
||||
caplog.set_level("WARNING")
|
||||
|
||||
try:
|
||||
await tool.coroutine(agent="codex", prompt="Do something")
|
||||
finally:
|
||||
sys.modules.pop("acp", None)
|
||||
sys.modules.pop("acp.schema", None)
|
||||
|
||||
assert captured["new_session"]["mcp_servers"] == []
|
||||
assert "continuing without MCP servers" in caplog.text
|
||||
assert "missing command" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_invoke_acp_agent_passes_none_env_when_not_configured(monkeypatch, tmp_path):
|
||||
"""When env is empty, None is passed to spawn_agent_process (subprocess inherits parent env)."""
|
||||
|
||||
@@ -0,0 +1,312 @@
|
||||
"""Tests for LangGraph Server auth handler (langgraph_auth.py).
|
||||
|
||||
Validates that the LangGraph auth layer enforces the same rules as Gateway:
|
||||
cookie → JWT decode → DB lookup → token_version check → owner filter
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")
|
||||
|
||||
from langgraph_sdk import Auth
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.langgraph_auth import add_owner_filter, authenticate
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||
|
||||
_JWT_SECRET = "test-secret-key-for-langgraph-auth-testing-min-32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _setup_auth_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||
yield
|
||||
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||
|
||||
|
||||
def _req(cookies=None, method="GET", headers=None):
|
||||
return SimpleNamespace(cookies=cookies or {}, method=method, headers=headers or {})
|
||||
|
||||
|
||||
def _user(user_id=None, token_version=0):
|
||||
return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)
|
||||
|
||||
|
||||
def _mock_provider(user=None):
|
||||
p = AsyncMock()
|
||||
p.get_user = AsyncMock(return_value=user)
|
||||
return p
|
||||
|
||||
|
||||
# ── @auth.authenticate ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_no_cookie_raises_401():
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req()))
|
||||
assert exc.value.status_code == 401
|
||||
assert "Not authenticated" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_invalid_jwt_raises_401():
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": "garbage"})))
|
||||
assert exc.value.status_code == 401
|
||||
assert "Token error" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_expired_jwt_raises_401():
|
||||
token = create_access_token("user-1", expires_delta=timedelta(seconds=-1))
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
def test_user_not_found_raises_401():
|
||||
token = create_access_token("ghost")
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)):
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
assert exc.value.status_code == 401
|
||||
assert "User not found" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_token_version_mismatch_raises_401():
|
||||
user = _user(token_version=2)
|
||||
token = create_access_token(str(user.id), token_version=1)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
assert exc.value.status_code == 401
|
||||
assert "revoked" in str(exc.value.detail).lower()
|
||||
|
||||
|
||||
def test_valid_token_returns_user_id():
|
||||
user = _user(token_version=0)
|
||||
token = create_access_token(str(user.id), token_version=0)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
result = asyncio.run(authenticate(_req({"access_token": token})))
|
||||
assert result == str(user.id)
|
||||
|
||||
|
||||
def test_valid_token_matching_version():
|
||||
user = _user(token_version=5)
|
||||
token = create_access_token(str(user.id), token_version=5)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
result = asyncio.run(authenticate(_req({"access_token": token})))
|
||||
assert result == str(user.id)
|
||||
|
||||
|
||||
# ── @auth.authenticate edge cases ────────────────────────────────────────
|
||||
|
||||
|
||||
def test_provider_exception_propagates():
|
||||
"""Provider raises → should not be swallowed silently."""
|
||||
token = create_access_token("user-1")
|
||||
p = AsyncMock()
|
||||
p.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p):
|
||||
with pytest.raises(RuntimeError, match="DB down"):
|
||||
asyncio.run(authenticate(_req({"access_token": token})))
|
||||
|
||||
|
||||
def test_jwt_missing_ver_defaults_to_zero():
|
||||
"""JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0."""
|
||||
import jwt as pyjwt
|
||||
|
||||
uid = str(uuid4())
|
||||
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
|
||||
user = _user(user_id=uid, token_version=0)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
result = asyncio.run(authenticate(_req({"access_token": raw})))
|
||||
assert result == uid
|
||||
|
||||
|
||||
def test_jwt_missing_ver_rejected_when_user_version_nonzero():
|
||||
"""JWT without 'ver' (defaults 0) vs user with token_version=1 → 401."""
|
||||
import jwt as pyjwt
|
||||
|
||||
uid = str(uuid4())
|
||||
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
|
||||
user = _user(user_id=uid, token_version=1)
|
||||
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": raw})))
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
def test_wrong_secret_raises_401():
|
||||
"""Token signed with different secret → 401."""
|
||||
import jwt as pyjwt
|
||||
|
||||
raw = pyjwt.encode({"sub": "user-1", "exp": 9999999999, "ver": 0}, "wrong-secret-that-is-long-enough-32chars!", algorithm="HS256")
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req({"access_token": raw})))
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
# ── @auth.on (owner filter) ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeUser:
|
||||
"""Minimal BaseUser-compatible object without langgraph_api.config dependency."""
|
||||
|
||||
def __init__(self, identity: str):
|
||||
self.identity = identity
|
||||
self.is_authenticated = True
|
||||
self.display_name = identity
|
||||
|
||||
|
||||
def _make_ctx(user_id):
|
||||
return Auth.types.AuthContext(resource="threads", action="create", user=_FakeUser(user_id), permissions=[])
|
||||
|
||||
|
||||
def test_filter_injects_user_id():
|
||||
value = {}
|
||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||
assert value["metadata"]["user_id"] == "user-a"
|
||||
|
||||
|
||||
def test_filter_preserves_existing_metadata():
|
||||
value = {"metadata": {"title": "hello"}}
|
||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||
assert value["metadata"]["user_id"] == "user-a"
|
||||
assert value["metadata"]["title"] == "hello"
|
||||
|
||||
|
||||
def test_filter_returns_user_id_dict():
|
||||
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
||||
assert result == {"user_id": "user-x"}
|
||||
|
||||
|
||||
def test_filter_read_write_consistency():
|
||||
value = {}
|
||||
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
||||
assert value["metadata"]["user_id"] == filter_dict["user_id"]
|
||||
|
||||
|
||||
def test_different_users_different_filters():
|
||||
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
||||
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
||||
assert f_a["user_id"] != f_b["user_id"]
|
||||
|
||||
|
||||
def test_filter_overrides_conflicting_user_id():
|
||||
"""If value already has a different user_id in metadata, it gets overwritten."""
|
||||
value = {"metadata": {"user_id": "attacker"}}
|
||||
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
||||
assert value["metadata"]["user_id"] == "real-owner"
|
||||
|
||||
|
||||
def test_filter_with_empty_metadata():
|
||||
"""Explicit empty metadata dict is fine."""
|
||||
value = {"metadata": {}}
|
||||
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
||||
assert value["metadata"]["user_id"] == "user-z"
|
||||
assert result == {"user_id": "user-z"}
|
||||
|
||||
|
||||
# ── Gateway parity ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_shared_jwt_secret():
|
||||
token = create_access_token("user-1", token_version=3)
|
||||
payload = decode_token(token)
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
assert not isinstance(payload, TokenError)
|
||||
assert payload.sub == "user-1"
|
||||
assert payload.ver == 3
|
||||
|
||||
|
||||
def test_langgraph_json_has_auth_path():
|
||||
import json
|
||||
|
||||
config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text())
|
||||
assert "auth" in config
|
||||
assert "langgraph_auth" in config["auth"]["path"]
|
||||
|
||||
|
||||
def test_auth_handler_has_both_layers():
|
||||
from app.gateway.langgraph_auth import auth
|
||||
|
||||
assert auth._authenticate_handler is not None
|
||||
assert len(auth._global_handlers) == 1
|
||||
|
||||
|
||||
# ── CSRF in LangGraph auth ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_csrf_get_no_check():
|
||||
"""GET requests skip CSRF — should proceed to JWT validation."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req(method="GET")))
|
||||
# Rejected by missing cookie, NOT by CSRF
|
||||
assert exc.value.status_code == 401
|
||||
assert "Not authenticated" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_csrf_post_missing_token():
|
||||
"""POST without CSRF token → 403."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req(method="POST", cookies={"access_token": "some-jwt"})))
|
||||
assert exc.value.status_code == 403
|
||||
assert "CSRF token missing" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_csrf_post_mismatched_token():
|
||||
"""POST with mismatched CSRF tokens → 403."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(
|
||||
authenticate(
|
||||
_req(
|
||||
method="POST",
|
||||
cookies={"access_token": "some-jwt", "csrf_token": "real-token"},
|
||||
headers={"x-csrf-token": "wrong-token"},
|
||||
)
|
||||
)
|
||||
)
|
||||
assert exc.value.status_code == 403
|
||||
assert "mismatch" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_csrf_post_matching_token_proceeds_to_jwt():
|
||||
"""POST with matching CSRF tokens passes CSRF check, then fails on JWT."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(
|
||||
authenticate(
|
||||
_req(
|
||||
method="POST",
|
||||
cookies={"access_token": "garbage", "csrf_token": "same-token"},
|
||||
headers={"x-csrf-token": "same-token"},
|
||||
)
|
||||
)
|
||||
)
|
||||
# Past CSRF, rejected by JWT decode
|
||||
assert exc.value.status_code == 401
|
||||
assert "Token error" in str(exc.value.detail)
|
||||
|
||||
|
||||
def test_csrf_put_requires_token():
|
||||
"""PUT also requires CSRF."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req(method="PUT", cookies={"access_token": "jwt"})))
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
def test_csrf_delete_requires_token():
|
||||
"""DELETE also requires CSRF."""
|
||||
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||
asyncio.run(authenticate(_req(method="DELETE", cookies={"access_token": "jwt"})))
|
||||
assert exc.value.status_code == 403
|
||||
@@ -0,0 +1,388 @@
|
||||
import errno
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
|
||||
from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider
|
||||
|
||||
|
||||
class TestPathMapping:
|
||||
def test_path_mapping_dataclass(self):
|
||||
mapping = PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True)
|
||||
assert mapping.container_path == "/mnt/skills"
|
||||
assert mapping.local_path == "/home/user/skills"
|
||||
assert mapping.read_only is True
|
||||
|
||||
def test_path_mapping_defaults_to_false(self):
|
||||
mapping = PathMapping(container_path="/mnt/data", local_path="/home/user/data")
|
||||
assert mapping.read_only is False
|
||||
|
||||
|
||||
class TestLocalSandboxPathResolution:
|
||||
def test_resolve_path_exact_match(self):
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
|
||||
],
|
||||
)
|
||||
resolved = sandbox._resolve_path("/mnt/skills")
|
||||
assert resolved == "/home/user/skills"
|
||||
|
||||
def test_resolve_path_nested_path(self):
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
|
||||
],
|
||||
)
|
||||
resolved = sandbox._resolve_path("/mnt/skills/agent/prompt.py")
|
||||
assert resolved == "/home/user/skills/agent/prompt.py"
|
||||
|
||||
def test_resolve_path_no_mapping(self):
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
|
||||
],
|
||||
)
|
||||
resolved = sandbox._resolve_path("/mnt/other/file.txt")
|
||||
assert resolved == "/mnt/other/file.txt"
|
||||
|
||||
def test_resolve_path_longest_prefix_first(self):
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
|
||||
PathMapping(container_path="/mnt", local_path="/var/mnt"),
|
||||
],
|
||||
)
|
||||
resolved = sandbox._resolve_path("/mnt/skills/file.py")
|
||||
# Should match /mnt/skills first (longer prefix)
|
||||
assert resolved == "/home/user/skills/file.py"
|
||||
|
||||
def test_reverse_resolve_path_exact_match(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)),
|
||||
],
|
||||
)
|
||||
resolved = sandbox._reverse_resolve_path(str(skills_dir))
|
||||
assert resolved == "/mnt/skills"
|
||||
|
||||
def test_reverse_resolve_path_nested(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
file_path = skills_dir / "agent" / "prompt.py"
|
||||
file_path.parent.mkdir()
|
||||
file_path.write_text("test")
|
||||
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)),
|
||||
],
|
||||
)
|
||||
resolved = sandbox._reverse_resolve_path(str(file_path))
|
||||
assert resolved == "/mnt/skills/agent/prompt.py"
|
||||
|
||||
|
||||
class TestReadOnlyPath:
|
||||
def test_is_read_only_true(self):
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
|
||||
],
|
||||
)
|
||||
assert sandbox._is_read_only_path("/home/user/skills/file.py") is True
|
||||
|
||||
def test_is_read_only_false_for_writable(self):
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/data", local_path="/home/user/data", read_only=False),
|
||||
],
|
||||
)
|
||||
assert sandbox._is_read_only_path("/home/user/data/file.txt") is False
|
||||
|
||||
def test_is_read_only_false_for_unmapped_path(self):
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
|
||||
],
|
||||
)
|
||||
# Path not under any mapping
|
||||
assert sandbox._is_read_only_path("/tmp/other/file.txt") is False
|
||||
|
||||
def test_is_read_only_true_for_exact_match(self):
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
|
||||
],
|
||||
)
|
||||
assert sandbox._is_read_only_path("/home/user/skills") is True
|
||||
|
||||
def test_write_file_blocked_on_read_only(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
|
||||
],
|
||||
)
|
||||
# Skills dir is read-only, write should be blocked
|
||||
with pytest.raises(OSError) as exc_info:
|
||||
sandbox.write_file("/mnt/skills/new_file.py", "content")
|
||||
assert exc_info.value.errno == errno.EROFS
|
||||
|
||||
def test_write_file_allowed_on_writable_mount(self, tmp_path):
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir()
|
||||
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False),
|
||||
],
|
||||
)
|
||||
sandbox.write_file("/mnt/data/file.txt", "content")
|
||||
assert (data_dir / "file.txt").read_text() == "content"
|
||||
|
||||
def test_update_file_blocked_on_read_only(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
existing_file = skills_dir / "existing.py"
|
||||
existing_file.write_bytes(b"original")
|
||||
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
|
||||
],
|
||||
)
|
||||
with pytest.raises(OSError) as exc_info:
|
||||
sandbox.update_file("/mnt/skills/existing.py", b"updated")
|
||||
assert exc_info.value.errno == errno.EROFS
|
||||
|
||||
|
||||
class TestMultipleMounts:
|
||||
def test_multiple_read_write_mounts(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir()
|
||||
external_dir = tmp_path / "external"
|
||||
external_dir.mkdir()
|
||||
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
|
||||
PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False),
|
||||
PathMapping(container_path="/mnt/external", local_path=str(external_dir), read_only=True),
|
||||
],
|
||||
)
|
||||
|
||||
# Skills is read-only
|
||||
with pytest.raises(OSError):
|
||||
sandbox.write_file("/mnt/skills/file.py", "content")
|
||||
|
||||
# Data is writable
|
||||
sandbox.write_file("/mnt/data/file.txt", "data content")
|
||||
assert (data_dir / "file.txt").read_text() == "data content"
|
||||
|
||||
# External is read-only
|
||||
with pytest.raises(OSError):
|
||||
sandbox.write_file("/mnt/external/file.txt", "content")
|
||||
|
||||
def test_nested_mounts_writable_under_readonly(self, tmp_path):
|
||||
"""A writable mount nested under a read-only mount should allow writes."""
|
||||
ro_dir = tmp_path / "ro"
|
||||
ro_dir.mkdir()
|
||||
rw_dir = ro_dir / "writable"
|
||||
rw_dir.mkdir()
|
||||
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/repo", local_path=str(ro_dir), read_only=True),
|
||||
PathMapping(container_path="/mnt/repo/writable", local_path=str(rw_dir), read_only=False),
|
||||
],
|
||||
)
|
||||
|
||||
# Parent mount is read-only
|
||||
with pytest.raises(OSError):
|
||||
sandbox.write_file("/mnt/repo/file.txt", "content")
|
||||
|
||||
# Nested writable mount should allow writes
|
||||
sandbox.write_file("/mnt/repo/writable/file.txt", "content")
|
||||
assert (rw_dir / "file.txt").read_text() == "content"
|
||||
|
||||
def test_execute_command_path_replacement(self, tmp_path, monkeypatch):
|
||||
data_dir = tmp_path / "data"
|
||||
data_dir.mkdir()
|
||||
test_file = data_dir / "test.txt"
|
||||
test_file.write_text("hello")
|
||||
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/data", local_path=str(data_dir)),
|
||||
],
|
||||
)
|
||||
|
||||
# Mock subprocess to capture the resolved command
|
||||
captured = {}
|
||||
original_run = __import__("subprocess").run
|
||||
|
||||
def mock_run(*args, **kwargs):
|
||||
if len(args) > 0:
|
||||
captured["command"] = args[0]
|
||||
return original_run(*args, **kwargs)
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.subprocess.run", mock_run)
|
||||
monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.LocalSandbox._get_shell", lambda self: "/bin/sh")
|
||||
|
||||
sandbox.execute_command("cat /mnt/data/test.txt")
|
||||
# Verify the command received the resolved local path
|
||||
assert str(data_dir) in captured.get("command", "")
|
||||
|
||||
def test_reverse_resolve_path_does_not_match_partial_prefix(self, tmp_path):
|
||||
foo_dir = tmp_path / "foo"
|
||||
foo_dir.mkdir()
|
||||
foobar_dir = tmp_path / "foobar"
|
||||
foobar_dir.mkdir()
|
||||
target = foobar_dir / "file.txt"
|
||||
target.write_text("test")
|
||||
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/foo", local_path=str(foo_dir)),
|
||||
],
|
||||
)
|
||||
|
||||
resolved = sandbox._reverse_resolve_path(str(target))
|
||||
assert resolved == str(target.resolve())
|
||||
|
||||
def test_reverse_resolve_paths_in_output_supports_backslash_separator(self, tmp_path):
|
||||
mount_dir = tmp_path / "mount"
|
||||
mount_dir.mkdir()
|
||||
sandbox = LocalSandbox(
|
||||
"test",
|
||||
[
|
||||
PathMapping(container_path="/mnt/data", local_path=str(mount_dir)),
|
||||
],
|
||||
)
|
||||
|
||||
output = f"Copied: {mount_dir}\\file.txt"
|
||||
masked = sandbox._reverse_resolve_paths_in_output(output)
|
||||
|
||||
assert "/mnt/data/file.txt" in masked
|
||||
assert str(mount_dir) not in masked
|
||||
|
||||
|
||||
class TestLocalSandboxProviderMounts:
|
||||
def test_setup_path_mappings_uses_configured_skills_container_path_as_reserved_prefix(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
custom_dir = tmp_path / "custom"
|
||||
custom_dir.mkdir()
|
||||
|
||||
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||
|
||||
sandbox_config = SandboxConfig(
|
||||
use="deerflow.sandbox.local:LocalSandboxProvider",
|
||||
mounts=[
|
||||
VolumeMountConfig(host_path=str(custom_dir), container_path="/custom-skills/nested", read_only=False),
|
||||
],
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/custom-skills", get_skills_path=lambda: skills_dir),
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]
|
||||
|
||||
def test_setup_path_mappings_skips_relative_host_path(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
|
||||
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||
|
||||
sandbox_config = SandboxConfig(
|
||||
use="deerflow.sandbox.local:LocalSandboxProvider",
|
||||
mounts=[
|
||||
VolumeMountConfig(host_path="relative/path", container_path="/mnt/data", read_only=False),
|
||||
],
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||
|
||||
def test_setup_path_mappings_skips_non_absolute_container_path(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
custom_dir = tmp_path / "custom"
|
||||
custom_dir.mkdir()
|
||||
|
||||
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||
|
||||
sandbox_config = SandboxConfig(
|
||||
use="deerflow.sandbox.local:LocalSandboxProvider",
|
||||
mounts=[
|
||||
VolumeMountConfig(host_path=str(custom_dir), container_path="mnt/data", read_only=False),
|
||||
],
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
|
||||
|
||||
def test_setup_path_mappings_normalizes_container_path_trailing_slash(self, tmp_path):
|
||||
skills_dir = tmp_path / "skills"
|
||||
skills_dir.mkdir()
|
||||
custom_dir = tmp_path / "custom"
|
||||
custom_dir.mkdir()
|
||||
|
||||
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||
|
||||
sandbox_config = SandboxConfig(
|
||||
use="deerflow.sandbox.local:LocalSandboxProvider",
|
||||
mounts=[
|
||||
VolumeMountConfig(host_path=str(custom_dir), container_path="/mnt/data/", read_only=False),
|
||||
],
|
||||
)
|
||||
config = SimpleNamespace(
|
||||
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
|
||||
sandbox=sandbox_config,
|
||||
)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=config):
|
||||
provider = LocalSandboxProvider()
|
||||
|
||||
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Tests for LoopDetectionMiddleware."""
|
||||
|
||||
import copy
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
@@ -19,8 +20,13 @@ def _make_runtime(thread_id="test-thread"):
|
||||
|
||||
|
||||
def _make_state(tool_calls=None, content=""):
|
||||
"""Build a minimal AgentState dict with an AIMessage."""
|
||||
msg = AIMessage(content=content, tool_calls=tool_calls or [])
|
||||
"""Build a minimal AgentState dict with an AIMessage.
|
||||
|
||||
Deep-copies *content* when it is mutable (e.g. list) so that
|
||||
successive calls never share the same object reference.
|
||||
"""
|
||||
safe_content = copy.deepcopy(content) if isinstance(content, list) else content
|
||||
msg = AIMessage(content=safe_content, tool_calls=tool_calls or [])
|
||||
return {"messages": [msg]}
|
||||
|
||||
|
||||
@@ -229,3 +235,114 @@ class TestLoopDetection:
|
||||
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert "default" in mw._history
|
||||
|
||||
|
||||
class TestAppendText:
|
||||
"""Unit tests for LoopDetectionMiddleware._append_text."""
|
||||
|
||||
def test_none_content_returns_text(self):
|
||||
result = LoopDetectionMiddleware._append_text(None, "hello")
|
||||
assert result == "hello"
|
||||
|
||||
def test_str_content_concatenates(self):
|
||||
result = LoopDetectionMiddleware._append_text("existing", "appended")
|
||||
assert result == "existing\n\nappended"
|
||||
|
||||
def test_empty_str_content_concatenates(self):
|
||||
result = LoopDetectionMiddleware._append_text("", "appended")
|
||||
assert result == "\n\nappended"
|
||||
|
||||
def test_list_content_appends_text_block(self):
|
||||
"""List content (e.g. Anthropic thinking mode) should get a new text block."""
|
||||
content = [
|
||||
{"type": "thinking", "text": "Let me think..."},
|
||||
{"type": "text", "text": "Here is my answer"},
|
||||
]
|
||||
result = LoopDetectionMiddleware._append_text(content, "stop msg")
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 3
|
||||
assert result[0] == content[0]
|
||||
assert result[1] == content[1]
|
||||
assert result[2] == {"type": "text", "text": "\n\nstop msg"}
|
||||
|
||||
def test_empty_list_content_appends_text_block(self):
|
||||
result = LoopDetectionMiddleware._append_text([], "stop msg")
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert result[0] == {"type": "text", "text": "\n\nstop msg"}
|
||||
|
||||
def test_unexpected_type_coerced_to_str(self):
|
||||
"""Unexpected content types should be coerced to str as a fallback."""
|
||||
result = LoopDetectionMiddleware._append_text(42, "stop msg")
|
||||
assert isinstance(result, str)
|
||||
assert result == "42\n\nstop msg"
|
||||
|
||||
def test_list_content_not_mutated_in_place(self):
|
||||
"""_append_text must not modify the original list."""
|
||||
original = [{"type": "text", "text": "hello"}]
|
||||
result = LoopDetectionMiddleware._append_text(original, "appended")
|
||||
assert len(original) == 1 # original unchanged
|
||||
assert len(result) == 2 # new list has the appended block
|
||||
|
||||
|
||||
class TestHardStopWithListContent:
|
||||
"""Regression tests: hard stop must not crash when AIMessage.content is a list."""
|
||||
|
||||
def test_hard_stop_with_list_content(self):
|
||||
"""Hard stop on list content should not raise TypeError (regression)."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
# Build state with list content (e.g. Anthropic thinking mode)
|
||||
list_content = [
|
||||
{"type": "thinking", "text": "Let me think..."},
|
||||
{"type": "text", "text": "I'll run ls"},
|
||||
]
|
||||
|
||||
for _ in range(3):
|
||||
mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
|
||||
|
||||
# Fourth call triggers hard stop — must not raise TypeError
|
||||
result = mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
|
||||
assert result is not None
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg, AIMessage)
|
||||
assert msg.tool_calls == []
|
||||
# Content should remain a list with the stop message appended
|
||||
assert isinstance(msg.content, list)
|
||||
assert len(msg.content) == 3
|
||||
assert msg.content[2]["type"] == "text"
|
||||
assert _HARD_STOP_MSG in msg.content[2]["text"]
|
||||
|
||||
def test_hard_stop_with_none_content(self):
|
||||
"""Hard stop on None content should produce a plain string."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
for _ in range(3):
|
||||
mw._apply(_make_state(tool_calls=call), runtime)
|
||||
|
||||
# Fourth call with default empty-string content
|
||||
result = mw._apply(_make_state(tool_calls=call), runtime)
|
||||
assert result is not None
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg.content, str)
|
||||
assert _HARD_STOP_MSG in msg.content
|
||||
|
||||
def test_hard_stop_with_str_content(self):
|
||||
"""Hard stop on str content should concatenate the stop message."""
|
||||
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
|
||||
runtime = _make_runtime()
|
||||
call = [_bash_call("ls")]
|
||||
|
||||
for _ in range(3):
|
||||
mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
|
||||
|
||||
result = mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
|
||||
assert result is not None
|
||||
msg = result["messages"][0]
|
||||
assert isinstance(msg.content, str)
|
||||
assert msg.content.startswith("thinking...")
|
||||
assert _HARD_STOP_MSG in msg.content
|
||||
|
||||
@@ -154,3 +154,22 @@ def test_format_memory_renders_correction_without_source_error_normally() -> Non
|
||||
|
||||
assert "Use make dev for local development." in result
|
||||
assert "avoid:" not in result
|
||||
|
||||
|
||||
def test_format_memory_includes_long_term_background() -> None:
|
||||
"""longTermBackground in history must be injected into the prompt."""
|
||||
memory_data = {
|
||||
"user": {},
|
||||
"history": {
|
||||
"recentMonths": {"summary": "Recent activity summary"},
|
||||
"earlierContext": {"summary": "Earlier context summary"},
|
||||
"longTermBackground": {"summary": "Core expertise in distributed systems"},
|
||||
},
|
||||
"facts": [],
|
||||
}
|
||||
|
||||
result = format_memory_for_injection(memory_data, max_tokens=2000)
|
||||
|
||||
assert "Background: Core expertise in distributed systems" in result
|
||||
assert "Recent: Recent activity summary" in result
|
||||
assert "Earlier: Earlier context summary" in result
|
||||
|
||||
@@ -47,4 +47,45 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
|
||||
thread_id="thread-1",
|
||||
agent_name="lead_agent",
|
||||
correction_detected=True,
|
||||
reinforcement_detected=False,
|
||||
)
|
||||
|
||||
|
||||
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
|
||||
with (
|
||||
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch.object(queue, "_reset_timer"),
|
||||
):
|
||||
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
|
||||
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
|
||||
|
||||
assert len(queue._queue) == 1
|
||||
assert queue._queue[0].messages == ["second"]
|
||||
assert queue._queue[0].reinforcement_detected is True
|
||||
|
||||
|
||||
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
|
||||
queue = MemoryUpdateQueue()
|
||||
queue._queue = [
|
||||
ConversationContext(
|
||||
thread_id="thread-1",
|
||||
messages=["conversation"],
|
||||
agent_name="lead_agent",
|
||||
reinforcement_detected=True,
|
||||
)
|
||||
]
|
||||
mock_updater = MagicMock()
|
||||
mock_updater.update_memory.return_value = True
|
||||
|
||||
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
|
||||
queue._process_queue()
|
||||
|
||||
mock_updater.update_memory.assert_called_once_with(
|
||||
messages=["conversation"],
|
||||
thread_id="thread-1",
|
||||
agent_name="lead_agent",
|
||||
correction_detected=False,
|
||||
reinforcement_detected=True,
|
||||
)
|
||||
|
||||
@@ -619,3 +619,156 @@ class TestUpdateMemoryStructuredResponse:
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" not in prompt
|
||||
|
||||
|
||||
class TestFactDeduplicationCaseInsensitive:
|
||||
"""Tests that fact deduplication is case-insensitive."""
|
||||
|
||||
def test_duplicate_fact_different_case_not_stored(self):
|
||||
updater = MemoryUpdater()
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_1",
|
||||
"content": "User prefers Python",
|
||||
"category": "preference",
|
||||
"confidence": 0.9,
|
||||
"createdAt": "2026-01-01T00:00:00Z",
|
||||
"source": "thread-a",
|
||||
},
|
||||
]
|
||||
)
|
||||
# Same fact with different casing should be treated as duplicate
|
||||
update_data = {
|
||||
"factsToRemove": [],
|
||||
"newFacts": [
|
||||
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
# Should still have only 1 fact (duplicate rejected)
|
||||
assert len(result["facts"]) == 1
|
||||
assert result["facts"][0]["content"] == "User prefers Python"
|
||||
|
||||
def test_unique_fact_different_case_and_content_stored(self):
|
||||
updater = MemoryUpdater()
|
||||
current_memory = _make_memory(
|
||||
facts=[
|
||||
{
|
||||
"id": "fact_1",
|
||||
"content": "User prefers Python",
|
||||
"category": "preference",
|
||||
"confidence": 0.9,
|
||||
"createdAt": "2026-01-01T00:00:00Z",
|
||||
"source": "thread-a",
|
||||
},
|
||||
]
|
||||
)
|
||||
update_data = {
|
||||
"factsToRemove": [],
|
||||
"newFacts": [
|
||||
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
|
||||
],
|
||||
}
|
||||
|
||||
with patch(
|
||||
"deerflow.agents.memory.updater.get_memory_config",
|
||||
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
|
||||
):
|
||||
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
|
||||
|
||||
assert len(result["facts"]) == 2
|
||||
|
||||
|
||||
class TestReinforcementHint:
|
||||
"""Tests that reinforcement_detected injects the correct hint into the prompt."""
|
||||
|
||||
@staticmethod
|
||||
def _make_mock_model(json_response: str):
|
||||
model = MagicMock()
|
||||
response = MagicMock()
|
||||
response.content = f"```json\n{json_response}\n```"
|
||||
model.invoke.return_value = response
|
||||
return model
|
||||
|
||||
def test_reinforcement_hint_injected_when_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Yes, exactly! That's what I needed."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Great to hear!"
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
def test_reinforcement_hint_absent_when_not_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "Tell me more."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Sure."
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Positive reinforcement signals were detected" not in prompt
|
||||
|
||||
def test_both_hints_present_when_both_detected(self):
|
||||
updater = MemoryUpdater()
|
||||
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
|
||||
model = self._make_mock_model(valid_json)
|
||||
|
||||
with (
|
||||
patch.object(updater, "_get_model", return_value=model),
|
||||
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
|
||||
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
|
||||
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
|
||||
):
|
||||
msg = MagicMock()
|
||||
msg.type = "human"
|
||||
msg.content = "No wait, that's wrong. Actually yes, exactly right."
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.type = "ai"
|
||||
ai_msg.content = "Got it."
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
|
||||
|
||||
assert result is True
|
||||
prompt = model.invoke.call_args[0][0]
|
||||
assert "Explicit correction signals were detected" in prompt
|
||||
assert "Positive reinforcement signals were detected" in prompt
|
||||
|
||||
@@ -10,7 +10,7 @@ persisting in long-term memory:
|
||||
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
|
||||
|
||||
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
|
||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
|
||||
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
@@ -270,3 +270,73 @@ class TestStripUploadMentionsFromMemory:
|
||||
mem = {"user": {}, "history": {}, "facts": []}
|
||||
result = _strip_upload_mentions_from_memory(mem)
|
||||
assert result == {"user": {}, "history": {}, "facts": []}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# detect_reinforcement
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
class TestDetectReinforcement:
|
||||
def test_detects_english_reinforcement_signal(self):
|
||||
msgs = [
|
||||
_human("Can you summarise it in bullet points?"),
|
||||
_ai("Here are the key points: ..."),
|
||||
_human("Yes, exactly! That's what I needed."),
|
||||
_ai("Glad it helped."),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is True
|
||||
|
||||
def test_detects_perfect_signal(self):
|
||||
msgs = [
|
||||
_human("Write it more concisely."),
|
||||
_ai("Here is the concise version."),
|
||||
_human("Perfect."),
|
||||
_ai("Great!"),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is True
|
||||
|
||||
def test_detects_chinese_reinforcement_signal(self):
|
||||
msgs = [
|
||||
_human("帮我用要点来总结"),
|
||||
_ai("好的,要点如下:..."),
|
||||
_human("完全正确,就是这个意思"),
|
||||
_ai("很高兴能帮到你"),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is True
|
||||
|
||||
def test_returns_false_without_signal(self):
|
||||
msgs = [
|
||||
_human("What does this function do?"),
|
||||
_ai("It processes the input data."),
|
||||
_human("Can you show me an example?"),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is False
|
||||
|
||||
def test_only_checks_recent_messages(self):
|
||||
# Reinforcement signal buried beyond the -6 window should not trigger
|
||||
msgs = [
|
||||
_human("Yes, exactly right."),
|
||||
_ai("Noted."),
|
||||
_human("Let's discuss tests."),
|
||||
_ai("Sure."),
|
||||
_human("What about linting?"),
|
||||
_ai("Use ruff."),
|
||||
_human("And formatting?"),
|
||||
_ai("Use make format."),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is False
|
||||
|
||||
def test_does_not_conflict_with_correction(self):
|
||||
# A message can trigger correction but not reinforcement
|
||||
msgs = [
|
||||
_human("That's wrong, try again."),
|
||||
_ai("Corrected."),
|
||||
]
|
||||
|
||||
assert detect_reinforcement(msgs) is False
|
||||
|
||||
@@ -0,0 +1,393 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox
|
||||
from deerflow.sandbox.local.local_sandbox import LocalSandbox
|
||||
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
|
||||
from deerflow.sandbox.tools import glob_tool, grep_tool
|
||||
|
||||
|
||||
def _make_runtime(tmp_path):
|
||||
workspace = tmp_path / "workspace"
|
||||
uploads = tmp_path / "uploads"
|
||||
outputs = tmp_path / "outputs"
|
||||
workspace.mkdir()
|
||||
uploads.mkdir()
|
||||
outputs.mkdir()
|
||||
return SimpleNamespace(
|
||||
state={
|
||||
"sandbox": {"sandbox_id": "local"},
|
||||
"thread_data": {
|
||||
"workspace_path": str(workspace),
|
||||
"uploads_path": str(uploads),
|
||||
"outputs_path": str(outputs),
|
||||
},
|
||||
},
|
||||
context={"thread_id": "thread-1"},
|
||||
)
|
||||
|
||||
|
||||
def test_glob_tool_returns_virtual_paths_and_ignores_common_dirs(tmp_path, monkeypatch) -> None:
|
||||
runtime = _make_runtime(tmp_path)
|
||||
workspace = tmp_path / "workspace"
|
||||
(workspace / "app.py").write_text("print('hi')\n", encoding="utf-8")
|
||||
(workspace / "pkg").mkdir()
|
||||
(workspace / "pkg" / "util.py").write_text("print('util')\n", encoding="utf-8")
|
||||
(workspace / "node_modules").mkdir()
|
||||
(workspace / "node_modules" / "skip.py").write_text("ignored\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
|
||||
result = glob_tool.func(
|
||||
runtime=runtime,
|
||||
description="find python files",
|
||||
pattern="**/*.py",
|
||||
path="/mnt/user-data/workspace",
|
||||
)
|
||||
|
||||
assert "/mnt/user-data/workspace/app.py" in result
|
||||
assert "/mnt/user-data/workspace/pkg/util.py" in result
|
||||
assert "node_modules" not in result
|
||||
assert str(workspace) not in result
|
||||
|
||||
|
||||
def test_glob_tool_supports_skills_virtual_paths(tmp_path, monkeypatch) -> None:
|
||||
runtime = _make_runtime(tmp_path)
|
||||
skills_dir = tmp_path / "skills"
|
||||
(skills_dir / "public" / "demo").mkdir(parents=True)
|
||||
(skills_dir / "public" / "demo" / "SKILL.md").write_text("# Demo\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
|
||||
with (
|
||||
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
|
||||
patch("deerflow.sandbox.tools._get_skills_host_path", return_value=str(skills_dir)),
|
||||
):
|
||||
result = glob_tool.func(
|
||||
runtime=runtime,
|
||||
description="find skills",
|
||||
pattern="**/SKILL.md",
|
||||
path="/mnt/skills",
|
||||
)
|
||||
|
||||
assert "/mnt/skills/public/demo/SKILL.md" in result
|
||||
assert str(skills_dir) not in result
|
||||
|
||||
|
||||
def test_grep_tool_filters_by_glob_and_skips_binary_files(tmp_path, monkeypatch) -> None:
|
||||
runtime = _make_runtime(tmp_path)
|
||||
workspace = tmp_path / "workspace"
|
||||
(workspace / "main.py").write_text("TODO = 'ship it'\nprint(TODO)\n", encoding="utf-8")
|
||||
(workspace / "notes.txt").write_text("TODO in txt should be filtered\n", encoding="utf-8")
|
||||
(workspace / "image.bin").write_bytes(b"\0binary TODO")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
|
||||
result = grep_tool.func(
|
||||
runtime=runtime,
|
||||
description="find todo references",
|
||||
pattern="TODO",
|
||||
path="/mnt/user-data/workspace",
|
||||
glob="**/*.py",
|
||||
)
|
||||
|
||||
assert "/mnt/user-data/workspace/main.py:1: TODO = 'ship it'" in result
|
||||
assert "notes.txt" not in result
|
||||
assert "image.bin" not in result
|
||||
assert str(workspace) not in result
|
||||
|
||||
|
||||
def test_grep_tool_truncates_results(tmp_path, monkeypatch) -> None:
|
||||
runtime = _make_runtime(tmp_path)
|
||||
workspace = tmp_path / "workspace"
|
||||
(workspace / "main.py").write_text("TODO one\nTODO two\nTODO three\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
# Prevent config.yaml tool config from overriding the caller-supplied max_results=2.
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.get_app_config", lambda: SimpleNamespace(get_tool_config=lambda name: None))
|
||||
|
||||
result = grep_tool.func(
|
||||
runtime=runtime,
|
||||
description="limit matches",
|
||||
pattern="TODO",
|
||||
path="/mnt/user-data/workspace",
|
||||
max_results=2,
|
||||
)
|
||||
|
||||
assert "Found 2 matches under /mnt/user-data/workspace (showing first 2)" in result
|
||||
assert "TODO one" in result
|
||||
assert "TODO two" in result
|
||||
assert "TODO three" not in result
|
||||
assert "Results truncated." in result
|
||||
|
||||
|
||||
def test_glob_tool_include_dirs_filters_nested_ignored_paths(tmp_path, monkeypatch) -> None:
|
||||
runtime = _make_runtime(tmp_path)
|
||||
workspace = tmp_path / "workspace"
|
||||
(workspace / "src").mkdir()
|
||||
(workspace / "src" / "main.py").write_text("x\n", encoding="utf-8")
|
||||
(workspace / "node_modules").mkdir()
|
||||
(workspace / "node_modules" / "lib").mkdir()
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
|
||||
result = glob_tool.func(
|
||||
runtime=runtime,
|
||||
description="find dirs",
|
||||
pattern="**",
|
||||
path="/mnt/user-data/workspace",
|
||||
include_dirs=True,
|
||||
)
|
||||
|
||||
assert "src" in result
|
||||
assert "node_modules" not in result
|
||||
|
||||
|
||||
def test_grep_tool_literal_mode(tmp_path, monkeypatch) -> None:
|
||||
runtime = _make_runtime(tmp_path)
|
||||
workspace = tmp_path / "workspace"
|
||||
(workspace / "file.py").write_text("price = (a+b)\nresult = a+b\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
|
||||
# literal=True should treat (a+b) as a plain string, not a regex group
|
||||
result = grep_tool.func(
|
||||
runtime=runtime,
|
||||
description="literal search",
|
||||
pattern="(a+b)",
|
||||
path="/mnt/user-data/workspace",
|
||||
literal=True,
|
||||
)
|
||||
|
||||
assert "price = (a+b)" in result
|
||||
assert "result = a+b" not in result
|
||||
|
||||
|
||||
def test_grep_tool_case_sensitive(tmp_path, monkeypatch) -> None:
|
||||
runtime = _make_runtime(tmp_path)
|
||||
workspace = tmp_path / "workspace"
|
||||
(workspace / "file.py").write_text("TODO: fix\ntodo: also fix\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
|
||||
result = grep_tool.func(
|
||||
runtime=runtime,
|
||||
description="case sensitive search",
|
||||
pattern="TODO",
|
||||
path="/mnt/user-data/workspace",
|
||||
case_sensitive=True,
|
||||
)
|
||||
|
||||
assert "TODO: fix" in result
|
||||
assert "todo: also fix" not in result
|
||||
|
||||
|
||||
def test_grep_tool_invalid_regex_returns_error(tmp_path, monkeypatch) -> None:
|
||||
runtime = _make_runtime(tmp_path)
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
|
||||
result = grep_tool.func(
|
||||
runtime=runtime,
|
||||
description="bad pattern",
|
||||
pattern="[invalid",
|
||||
path="/mnt/user-data/workspace",
|
||||
)
|
||||
|
||||
assert "Invalid regex pattern" in result
|
||||
|
||||
|
||||
def test_aio_sandbox_glob_include_dirs_filters_nested_ignored(monkeypatch) -> None:
|
||||
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
|
||||
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
||||
monkeypatch.setattr(
|
||||
sandbox._client.file,
|
||||
"list_path",
|
||||
lambda **kwargs: SimpleNamespace(
|
||||
data=SimpleNamespace(
|
||||
files=[
|
||||
SimpleNamespace(name="src", path="/mnt/workspace/src"),
|
||||
SimpleNamespace(name="node_modules", path="/mnt/workspace/node_modules"),
|
||||
# child of node_modules — should be filtered via should_ignore_path
|
||||
SimpleNamespace(name="lib", path="/mnt/workspace/node_modules/lib"),
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
matches, truncated = sandbox.glob("/mnt/workspace", "**", include_dirs=True)
|
||||
|
||||
assert "/mnt/workspace/src" in matches
|
||||
assert "/mnt/workspace/node_modules" not in matches
|
||||
assert "/mnt/workspace/node_modules/lib" not in matches
|
||||
assert truncated is False
|
||||
|
||||
|
||||
def test_aio_sandbox_grep_invalid_regex_raises() -> None:
|
||||
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
|
||||
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
||||
|
||||
import re
|
||||
|
||||
try:
|
||||
sandbox.grep("/mnt/workspace", "[invalid")
|
||||
assert False, "Expected re.error"
|
||||
except re.error:
|
||||
pass
|
||||
|
||||
|
||||
def test_aio_sandbox_glob_parses_json(monkeypatch) -> None:
|
||||
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
|
||||
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
||||
monkeypatch.setattr(
|
||||
sandbox._client.file,
|
||||
"find_files",
|
||||
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(files=["/mnt/user-data/workspace/app.py", "/mnt/user-data/workspace/node_modules/skip.py"])),
|
||||
)
|
||||
|
||||
matches, truncated = sandbox.glob("/mnt/user-data/workspace", "**/*.py")
|
||||
|
||||
assert matches == ["/mnt/user-data/workspace/app.py"]
|
||||
assert truncated is False
|
||||
|
||||
|
||||
def test_aio_sandbox_grep_parses_json(monkeypatch) -> None:
|
||||
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
|
||||
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
||||
monkeypatch.setattr(
|
||||
sandbox._client.file,
|
||||
"list_path",
|
||||
lambda **kwargs: SimpleNamespace(
|
||||
data=SimpleNamespace(
|
||||
files=[
|
||||
SimpleNamespace(
|
||||
name="app.py",
|
||||
path="/mnt/user-data/workspace/app.py",
|
||||
is_directory=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
sandbox._client.file,
|
||||
"search_in_file",
|
||||
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(line_numbers=[7], matches=["TODO = True"])),
|
||||
)
|
||||
|
||||
matches, truncated = sandbox.grep("/mnt/user-data/workspace", "TODO")
|
||||
|
||||
assert matches == [GrepMatch(path="/mnt/user-data/workspace/app.py", line_number=7, line="TODO = True")]
|
||||
assert truncated is False
|
||||
|
||||
|
||||
def test_find_glob_matches_raises_not_a_directory(tmp_path) -> None:
|
||||
file_path = tmp_path / "file.txt"
|
||||
file_path.write_text("x\n", encoding="utf-8")
|
||||
|
||||
try:
|
||||
find_glob_matches(file_path, "**/*.py")
|
||||
assert False, "Expected NotADirectoryError"
|
||||
except NotADirectoryError:
|
||||
pass
|
||||
|
||||
|
||||
def test_find_grep_matches_raises_not_a_directory(tmp_path) -> None:
|
||||
file_path = tmp_path / "file.txt"
|
||||
file_path.write_text("TODO\n", encoding="utf-8")
|
||||
|
||||
try:
|
||||
find_grep_matches(file_path, "TODO")
|
||||
assert False, "Expected NotADirectoryError"
|
||||
except NotADirectoryError:
|
||||
pass
|
||||
|
||||
|
||||
def test_find_grep_matches_skips_symlink_outside_root(tmp_path) -> None:
|
||||
workspace = tmp_path / "workspace"
|
||||
workspace.mkdir()
|
||||
outside = tmp_path / "outside.txt"
|
||||
outside.write_text("TODO outside\n", encoding="utf-8")
|
||||
(workspace / "outside-link.txt").symlink_to(outside)
|
||||
|
||||
matches, truncated = find_grep_matches(workspace, "TODO")
|
||||
|
||||
assert matches == []
|
||||
assert truncated is False
|
||||
|
||||
|
||||
def test_glob_tool_honors_smaller_requested_max_results(tmp_path, monkeypatch) -> None:
|
||||
runtime = _make_runtime(tmp_path)
|
||||
workspace = tmp_path / "workspace"
|
||||
(workspace / "a.py").write_text("print('a')\n", encoding="utf-8")
|
||||
(workspace / "b.py").write_text("print('b')\n", encoding="utf-8")
|
||||
(workspace / "c.py").write_text("print('c')\n", encoding="utf-8")
|
||||
|
||||
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
|
||||
monkeypatch.setattr(
|
||||
"deerflow.sandbox.tools.get_app_config",
|
||||
lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50})),
|
||||
)
|
||||
|
||||
result = glob_tool.func(
|
||||
runtime=runtime,
|
||||
description="limit glob matches",
|
||||
pattern="**/*.py",
|
||||
path="/mnt/user-data/workspace",
|
||||
max_results=2,
|
||||
)
|
||||
|
||||
assert "Found 2 paths under /mnt/user-data/workspace (showing first 2)" in result
|
||||
assert "Results truncated." in result
|
||||
|
||||
|
||||
def test_aio_sandbox_glob_include_dirs_enforces_root_boundary(monkeypatch) -> None:
|
||||
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
|
||||
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
||||
monkeypatch.setattr(
|
||||
sandbox._client.file,
|
||||
"list_path",
|
||||
lambda **kwargs: SimpleNamespace(
|
||||
data=SimpleNamespace(
|
||||
files=[
|
||||
SimpleNamespace(name="src", path="/mnt/workspace/src"),
|
||||
SimpleNamespace(name="src2", path="/mnt/workspace2/src2"),
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
matches, truncated = sandbox.glob("/mnt/workspace", "**", include_dirs=True)
|
||||
|
||||
assert matches == ["/mnt/workspace/src"]
|
||||
assert truncated is False
|
||||
|
||||
|
||||
def test_aio_sandbox_grep_skips_mismatched_line_number_payloads(monkeypatch) -> None:
|
||||
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
|
||||
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
|
||||
monkeypatch.setattr(
|
||||
sandbox._client.file,
|
||||
"list_path",
|
||||
lambda **kwargs: SimpleNamespace(
|
||||
data=SimpleNamespace(
|
||||
files=[
|
||||
SimpleNamespace(
|
||||
name="app.py",
|
||||
path="/mnt/user-data/workspace/app.py",
|
||||
is_directory=False,
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
sandbox._client.file,
|
||||
"search_in_file",
|
||||
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(line_numbers=[7], matches=["TODO = True", "extra"])),
|
||||
)
|
||||
|
||||
matches, truncated = sandbox.grep("/mnt/user-data/workspace", "TODO")
|
||||
|
||||
assert matches == [GrepMatch(path="/mnt/user-data/workspace/app.py", line_number=7, line="TODO = True")]
|
||||
assert truncated is False
|
||||
@@ -8,7 +8,10 @@ import pytest
|
||||
from deerflow.sandbox.tools import (
|
||||
VIRTUAL_PATH_PREFIX,
|
||||
_apply_cwd_prefix,
|
||||
_get_custom_mount_for_path,
|
||||
_get_custom_mounts,
|
||||
_is_acp_workspace_path,
|
||||
_is_custom_mount_path,
|
||||
_is_skills_path,
|
||||
_reject_path_traversal,
|
||||
_resolve_acp_workspace_path,
|
||||
@@ -39,6 +42,53 @@ def test_replace_virtual_path_maps_virtual_root_and_subpaths() -> None:
|
||||
assert Path(replace_virtual_path("/mnt/user-data", _THREAD_DATA)).as_posix() == "/tmp/deer-flow/threads/t1/user-data"
|
||||
|
||||
|
||||
def test_replace_virtual_path_preserves_trailing_slash() -> None:
|
||||
"""Trailing slash must survive virtual-to-actual path replacement.
|
||||
|
||||
Regression: '/mnt/user-data/workspace/' was previously returned without
|
||||
the trailing slash, causing string concatenations like
|
||||
output_dir + 'file.txt' to produce a missing-separator path.
|
||||
"""
|
||||
result = replace_virtual_path("/mnt/user-data/workspace/", _THREAD_DATA)
|
||||
assert result.endswith("/"), f"Expected trailing slash, got: {result!r}"
|
||||
assert result == "/tmp/deer-flow/threads/t1/user-data/workspace/"
|
||||
|
||||
|
||||
def test_replace_virtual_path_preserves_trailing_slash_windows_style() -> None:
|
||||
"""Trailing slash must be preserved as backslash when actual_base is Windows-style.
|
||||
|
||||
If actual_base uses backslash separators, appending '/' would produce a
|
||||
mixed-separator path. The separator must match the style of actual_base.
|
||||
"""
|
||||
win_thread_data = {
|
||||
"workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
|
||||
"uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
|
||||
"outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
|
||||
}
|
||||
result = replace_virtual_path("/mnt/user-data/workspace/", win_thread_data)
|
||||
assert result.endswith("\\"), f"Expected trailing backslash for Windows path, got: {result!r}"
|
||||
assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
|
||||
|
||||
|
||||
def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing_slash() -> None:
|
||||
"""Nested Windows-style subdirectories must keep backslashes throughout."""
|
||||
win_thread_data = {
|
||||
"workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
|
||||
"uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
|
||||
"outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
|
||||
}
|
||||
result = replace_virtual_path("/mnt/user-data/workspace/subdir/", win_thread_data)
|
||||
assert result == "C:\\deer-flow\\threads\\t1\\user-data\\workspace\\subdir\\"
|
||||
assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
|
||||
|
||||
|
||||
def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
|
||||
"""Trailing slash on a virtual path inside a command must be preserved."""
|
||||
cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
|
||||
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
|
||||
assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
|
||||
|
||||
|
||||
# ---------- mask_local_paths_in_output ----------
|
||||
|
||||
|
||||
@@ -96,6 +146,25 @@ def test_validate_local_tool_path_rejects_non_virtual_path() -> None:
|
||||
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_non_virtual_path_mentions_configured_mounts() -> None:
|
||||
with pytest.raises(PermissionError, match="configured mount paths"):
|
||||
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -> None:
|
||||
from deerflow.config.sandbox_config import VolumeMountConfig
|
||||
|
||||
mounts = [
|
||||
VolumeMountConfig(host_path="/tmp/host-user-data", container_path=VIRTUAL_PATH_PREFIX, read_only=False),
|
||||
]
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=True)
|
||||
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_rejects_bare_virtual_root() -> None:
|
||||
"""The bare /mnt/user-data root without trailing slash is not a valid sub-path."""
|
||||
with pytest.raises(PermissionError, match="Only paths under"):
|
||||
@@ -235,6 +304,22 @@ def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
|
||||
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_https_urls() -> None:
|
||||
"""URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
|
||||
validate_local_bash_command_paths(
|
||||
"cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_http_urls() -> None:
|
||||
"""HTTP URLs must not be flagged as unsafe absolute paths."""
|
||||
validate_local_bash_command_paths(
|
||||
"curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
|
||||
_THREAD_DATA,
|
||||
)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
|
||||
validate_local_bash_command_paths(
|
||||
"/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
|
||||
@@ -567,6 +652,156 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
|
||||
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
|
||||
|
||||
|
||||
# ---------- Custom mount path tests ----------
|
||||
|
||||
|
||||
def _mock_custom_mounts():
|
||||
"""Create mock VolumeMountConfig objects for testing."""
|
||||
from deerflow.config.sandbox_config import VolumeMountConfig
|
||||
|
||||
return [
|
||||
VolumeMountConfig(host_path="/home/user/code-read", container_path="/mnt/code-read", read_only=True),
|
||||
VolumeMountConfig(host_path="/home/user/data", container_path="/mnt/data", read_only=False),
|
||||
]
|
||||
|
||||
|
||||
def test_is_custom_mount_path_recognises_configured_mounts() -> None:
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
assert _is_custom_mount_path("/mnt/code-read") is True
|
||||
assert _is_custom_mount_path("/mnt/code-read/src/main.py") is True
|
||||
assert _is_custom_mount_path("/mnt/data") is True
|
||||
assert _is_custom_mount_path("/mnt/data/file.txt") is True
|
||||
assert _is_custom_mount_path("/mnt/code-read-extra/foo") is False
|
||||
assert _is_custom_mount_path("/mnt/other") is False
|
||||
|
||||
|
||||
def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
|
||||
from deerflow.config.sandbox_config import VolumeMountConfig
|
||||
|
||||
mounts = [
|
||||
VolumeMountConfig(host_path="/var/mnt", container_path="/mnt", read_only=False),
|
||||
VolumeMountConfig(host_path="/home/user/code", container_path="/mnt/code", read_only=True),
|
||||
]
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
|
||||
mount = _get_custom_mount_for_path("/mnt/code/file.py")
|
||||
assert mount is not None
|
||||
assert mount.container_path == "/mnt/code"
|
||||
|
||||
|
||||
def test_validate_local_tool_path_allows_custom_mount_read() -> None:
|
||||
"""read_file / ls should be able to access custom mount paths."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=True)
|
||||
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_blocks_read_only_mount_write() -> None:
|
||||
"""write_file / str_replace must NOT write to read-only custom mounts."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="Write access to read-only mount is not allowed"):
|
||||
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=False)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_allows_writable_mount_write() -> None:
|
||||
"""write_file / str_replace should succeed on writable custom mounts."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=False)
|
||||
|
||||
|
||||
def test_validate_local_tool_path_blocks_traversal_in_custom_mount() -> None:
|
||||
"""Path traversal via .. in custom mount paths must be rejected."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, read_only=True)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_allows_custom_mount() -> None:
|
||||
"""bash commands referencing custom mount paths should be allowed."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA)
|
||||
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_blocks_traversal_in_custom_mount() -> None:
|
||||
"""Bash commands with traversal in custom mount paths should be blocked."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="path traversal"):
|
||||
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA)
|
||||
|
||||
|
||||
def test_validate_local_bash_command_paths_still_blocks_non_mount_paths() -> None:
|
||||
"""Paths not matching any custom mount should still be blocked."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
|
||||
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
|
||||
|
||||
|
||||
def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
|
||||
"""_get_custom_mounts should cache after first successful load."""
|
||||
# Clear any existing cache
|
||||
if hasattr(_get_custom_mounts, "_cached"):
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
|
||||
# Use real directories so host_path.exists() filtering passes
|
||||
dir_a = tmp_path / "code-read"
|
||||
dir_a.mkdir()
|
||||
dir_b = tmp_path / "data"
|
||||
dir_b.mkdir()
|
||||
|
||||
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||
|
||||
mounts = [
|
||||
VolumeMountConfig(host_path=str(dir_a), container_path="/mnt/code-read", read_only=True),
|
||||
VolumeMountConfig(host_path=str(dir_b), container_path="/mnt/data", read_only=False),
|
||||
]
|
||||
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
|
||||
mock_config = SimpleNamespace(sandbox=mock_sandbox)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=mock_config):
|
||||
result = _get_custom_mounts()
|
||||
assert len(result) == 2
|
||||
|
||||
# After caching, should return cached value even without mock
|
||||
assert hasattr(_get_custom_mounts, "_cached")
|
||||
assert len(_get_custom_mounts()) == 2
|
||||
|
||||
# Cleanup
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
|
||||
|
||||
def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path) -> None:
|
||||
"""_get_custom_mounts should only return mounts whose host_path exists."""
|
||||
if hasattr(_get_custom_mounts, "_cached"):
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
|
||||
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
|
||||
|
||||
existing_dir = tmp_path / "existing"
|
||||
existing_dir.mkdir()
|
||||
|
||||
mounts = [
|
||||
VolumeMountConfig(host_path=str(existing_dir), container_path="/mnt/existing", read_only=True),
|
||||
VolumeMountConfig(host_path="/nonexistent/path/12345", container_path="/mnt/ghost", read_only=False),
|
||||
]
|
||||
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
|
||||
mock_config = SimpleNamespace(sandbox=mock_sandbox)
|
||||
|
||||
with patch("deerflow.config.get_app_config", return_value=mock_config):
|
||||
result = _get_custom_mounts()
|
||||
assert len(result) == 1
|
||||
assert result[0].container_path == "/mnt/existing"
|
||||
|
||||
# Cleanup
|
||||
monkeypatch.delattr(_get_custom_mounts, "_cached")
|
||||
|
||||
|
||||
def test_get_custom_mount_for_path_boundary_no_false_prefix_match() -> None:
|
||||
"""_get_custom_mount_for_path must not match /mnt/code-read-extra for /mnt/code-read."""
|
||||
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
|
||||
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo")
|
||||
assert mount is None
|
||||
|
||||
|
||||
def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) -> None:
|
||||
class SharedSandbox:
|
||||
def __init__(self) -> None:
|
||||
|
||||
@@ -140,6 +140,193 @@ async def test_event_id_format(bridge: MemoryStreamBridge):
|
||||
assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# END sentinel guarantee tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_end_sentinel_delivered_when_queue_full():
|
||||
"""END sentinel must always be delivered, even when the queue is completely full.
|
||||
|
||||
This is the critical regression test for the bug where publish_end()
|
||||
would silently drop the END sentinel when the queue was full, causing
|
||||
subscribe() to hang forever and leaking resources.
|
||||
"""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=2)
|
||||
run_id = "run-end-full"
|
||||
|
||||
# Fill the queue to capacity
|
||||
await bridge.publish(run_id, "event-1", {"n": 1})
|
||||
await bridge.publish(run_id, "event-2", {"n": 2})
|
||||
assert bridge._queues[run_id].full()
|
||||
|
||||
# publish_end should succeed by evicting old events
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
# Subscriber must receive END_SENTINEL
|
||||
events = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
|
||||
events.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
assert any(e is END_SENTINEL for e in events), "END sentinel was not delivered"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_end_sentinel_evicts_oldest_events():
|
||||
"""When queue is full, publish_end evicts the oldest events to make room."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=1)
|
||||
run_id = "run-evict"
|
||||
|
||||
# Fill queue with one event
|
||||
await bridge.publish(run_id, "will-be-evicted", {})
|
||||
assert bridge._queues[run_id].full()
|
||||
|
||||
# publish_end must succeed
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
# The only event we should get is END_SENTINEL (the regular event was evicted)
|
||||
events = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
|
||||
events.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
assert len(events) == 1
|
||||
assert events[0] is END_SENTINEL
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_end_sentinel_no_eviction_when_space_available():
|
||||
"""When queue has space, publish_end should not evict anything."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=10)
|
||||
run_id = "run-no-evict"
|
||||
|
||||
await bridge.publish(run_id, "event-1", {"n": 1})
|
||||
await bridge.publish(run_id, "event-2", {"n": 2})
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
events = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
|
||||
events.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
break
|
||||
|
||||
# All events plus END should be present
|
||||
assert len(events) == 3
|
||||
assert events[0].event == "event-1"
|
||||
assert events[1].event == "event-2"
|
||||
assert events[2] is END_SENTINEL
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_concurrent_tasks_end_sentinel():
|
||||
"""Multiple concurrent producer/consumer pairs should all terminate properly.
|
||||
|
||||
Simulates the production scenario where multiple runs share a single
|
||||
bridge instance — each must receive its own END sentinel.
|
||||
"""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=4)
|
||||
num_runs = 4
|
||||
|
||||
async def producer(run_id: str):
|
||||
for i in range(10): # More events than queue capacity
|
||||
await bridge.publish(run_id, f"event-{i}", {"i": i})
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
async def consumer(run_id: str) -> list:
|
||||
events = []
|
||||
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
|
||||
events.append(entry)
|
||||
if entry is END_SENTINEL:
|
||||
return events
|
||||
return events # pragma: no cover
|
||||
|
||||
# Run producers and consumers concurrently
|
||||
run_ids = [f"concurrent-{i}" for i in range(num_runs)]
|
||||
producers = [producer(rid) for rid in run_ids]
|
||||
consumers = [consumer(rid) for rid in run_ids]
|
||||
|
||||
# Start consumers first, then producers
|
||||
consumer_tasks = [asyncio.create_task(c) for c in consumers]
|
||||
await asyncio.gather(*producers)
|
||||
|
||||
results = await asyncio.wait_for(
|
||||
asyncio.gather(*consumer_tasks),
|
||||
timeout=10.0,
|
||||
)
|
||||
|
||||
for i, events in enumerate(results):
|
||||
assert events[-1] is END_SENTINEL, f"Run {run_ids[i]} did not receive END sentinel"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Drop counter tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dropped_count_tracking():
|
||||
"""Dropped events should be tracked per run_id."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=1)
|
||||
run_id = "run-drop-count"
|
||||
|
||||
# Fill the queue
|
||||
await bridge.publish(run_id, "first", {})
|
||||
|
||||
# This publish will time out and be dropped (we patch timeout to be instant)
|
||||
# Instead, we verify the counter after publish_end eviction
|
||||
await bridge.publish_end(run_id)
|
||||
|
||||
# dropped_count tracks publish() drops, not publish_end evictions
|
||||
assert bridge.dropped_count(run_id) == 0
|
||||
|
||||
# cleanup should also clear the counter
|
||||
await bridge.cleanup(run_id)
|
||||
assert bridge.dropped_count(run_id) == 0
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_dropped_total():
|
||||
"""dropped_total should sum across all runs."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=256)
|
||||
|
||||
# No drops yet
|
||||
assert bridge.dropped_total == 0
|
||||
|
||||
# Manually set some counts to verify the property
|
||||
bridge._dropped_counts["run-a"] = 3
|
||||
bridge._dropped_counts["run-b"] = 7
|
||||
assert bridge.dropped_total == 10
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cleanup_clears_dropped_counts():
|
||||
"""cleanup() should clear the dropped counter for the run."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=256)
|
||||
run_id = "run-cleanup-drops"
|
||||
|
||||
bridge._get_or_create_queue(run_id)
|
||||
bridge._dropped_counts[run_id] = 5
|
||||
|
||||
await bridge.cleanup(run_id)
|
||||
assert run_id not in bridge._dropped_counts
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_close_clears_dropped_counts():
|
||||
"""close() should clear all dropped counters."""
|
||||
bridge = MemoryStreamBridge(queue_maxsize=256)
|
||||
bridge._dropped_counts["run-x"] = 10
|
||||
bridge._dropped_counts["run-y"] = 20
|
||||
|
||||
await bridge.close()
|
||||
assert bridge.dropped_total == 0
|
||||
assert len(bridge._dropped_counts) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Factory tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""Tests for subagent timeout configuration.
|
||||
"""Tests for subagent runtime configuration.
|
||||
|
||||
Covers:
|
||||
- SubagentsAppConfig / SubagentOverrideConfig model validation and defaults
|
||||
- get_timeout_for() resolution logic (global vs per-agent)
|
||||
- get_timeout_for() / get_max_turns_for() resolution logic
|
||||
- load_subagents_config_from_dict() and get_subagents_app_config() singleton
|
||||
- registry.get_subagent_config() applies config overrides
|
||||
- registry.list_subagents() applies overrides for all agents
|
||||
@@ -24,9 +24,20 @@ from deerflow.subagents.config import SubagentConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _reset_subagents_config(timeout_seconds: int = 900, agents: dict | None = None) -> None:
|
||||
def _reset_subagents_config(
|
||||
timeout_seconds: int = 900,
|
||||
*,
|
||||
max_turns: int | None = None,
|
||||
agents: dict | None = None,
|
||||
) -> None:
|
||||
"""Reset global subagents config to a known state."""
|
||||
load_subagents_config_from_dict({"timeout_seconds": timeout_seconds, "agents": agents or {}})
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": timeout_seconds,
|
||||
"max_turns": max_turns,
|
||||
"agents": agents or {},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -38,22 +49,29 @@ class TestSubagentOverrideConfig:
|
||||
def test_default_is_none(self):
|
||||
override = SubagentOverrideConfig()
|
||||
assert override.timeout_seconds is None
|
||||
assert override.max_turns is None
|
||||
|
||||
def test_explicit_value(self):
|
||||
override = SubagentOverrideConfig(timeout_seconds=300)
|
||||
override = SubagentOverrideConfig(timeout_seconds=300, max_turns=42)
|
||||
assert override.timeout_seconds == 300
|
||||
assert override.max_turns == 42
|
||||
|
||||
def test_rejects_zero(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(timeout_seconds=0)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(max_turns=0)
|
||||
|
||||
def test_rejects_negative(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(timeout_seconds=-1)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentOverrideConfig(max_turns=-1)
|
||||
|
||||
def test_minimum_valid_value(self):
|
||||
override = SubagentOverrideConfig(timeout_seconds=1)
|
||||
override = SubagentOverrideConfig(timeout_seconds=1, max_turns=1)
|
||||
assert override.timeout_seconds == 1
|
||||
assert override.max_turns == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -66,66 +84,86 @@ class TestSubagentsAppConfigDefaults:
|
||||
config = SubagentsAppConfig()
|
||||
assert config.timeout_seconds == 900
|
||||
|
||||
def test_default_max_turns_override_is_none(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.max_turns is None
|
||||
|
||||
def test_default_agents_empty(self):
|
||||
config = SubagentsAppConfig()
|
||||
assert config.agents == {}
|
||||
|
||||
def test_custom_global_timeout(self):
|
||||
config = SubagentsAppConfig(timeout_seconds=1800)
|
||||
def test_custom_global_runtime_overrides(self):
|
||||
config = SubagentsAppConfig(timeout_seconds=1800, max_turns=120)
|
||||
assert config.timeout_seconds == 1800
|
||||
assert config.max_turns == 120
|
||||
|
||||
def test_rejects_zero_timeout(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(timeout_seconds=0)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(max_turns=0)
|
||||
|
||||
def test_rejects_negative_timeout(self):
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(timeout_seconds=-60)
|
||||
with pytest.raises(ValueError):
|
||||
SubagentsAppConfig(max_turns=-60)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SubagentsAppConfig.get_timeout_for()
|
||||
# SubagentsAppConfig resolution helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetTimeoutFor:
|
||||
class TestRuntimeResolution:
|
||||
def test_returns_global_default_when_no_override(self):
|
||||
config = SubagentsAppConfig(timeout_seconds=600)
|
||||
assert config.get_timeout_for("general-purpose") == 600
|
||||
assert config.get_timeout_for("bash") == 600
|
||||
assert config.get_timeout_for("unknown-agent") == 600
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 100
|
||||
assert config.get_max_turns_for("bash", 60) == 60
|
||||
|
||||
def test_returns_per_agent_override_when_set(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=300)},
|
||||
max_turns=120,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
|
||||
)
|
||||
assert config.get_timeout_for("bash") == 300
|
||||
assert config.get_max_turns_for("bash", 60) == 80
|
||||
|
||||
def test_other_agents_still_use_global_default(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=300)},
|
||||
max_turns=140,
|
||||
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
|
||||
)
|
||||
assert config.get_timeout_for("general-purpose") == 900
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 140
|
||||
|
||||
def test_agent_with_none_override_falls_back_to_global(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None)},
|
||||
max_turns=150,
|
||||
agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None, max_turns=None)},
|
||||
)
|
||||
assert config.get_timeout_for("general-purpose") == 900
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 150
|
||||
|
||||
def test_multiple_per_agent_overrides(self):
|
||||
config = SubagentsAppConfig(
|
||||
timeout_seconds=900,
|
||||
max_turns=120,
|
||||
agents={
|
||||
"general-purpose": SubagentOverrideConfig(timeout_seconds=1800),
|
||||
"bash": SubagentOverrideConfig(timeout_seconds=120),
|
||||
"general-purpose": SubagentOverrideConfig(timeout_seconds=1800, max_turns=200),
|
||||
"bash": SubagentOverrideConfig(timeout_seconds=120, max_turns=80),
|
||||
},
|
||||
)
|
||||
assert config.get_timeout_for("general-purpose") == 1800
|
||||
assert config.get_timeout_for("bash") == 120
|
||||
assert config.get_max_turns_for("general-purpose", 100) == 200
|
||||
assert config.get_max_turns_for("bash", 60) == 80
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -139,54 +177,63 @@ class TestLoadSubagentsConfig:
|
||||
_reset_subagents_config()
|
||||
|
||||
def test_load_global_timeout(self):
|
||||
load_subagents_config_from_dict({"timeout_seconds": 300})
|
||||
load_subagents_config_from_dict({"timeout_seconds": 300, "max_turns": 120})
|
||||
assert get_subagents_app_config().timeout_seconds == 300
|
||||
assert get_subagents_app_config().max_turns == 120
|
||||
|
||||
def test_load_with_per_agent_overrides(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"max_turns": 120,
|
||||
"agents": {
|
||||
"general-purpose": {"timeout_seconds": 1800},
|
||||
"bash": {"timeout_seconds": 60},
|
||||
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
|
||||
"bash": {"timeout_seconds": 60, "max_turns": 80},
|
||||
},
|
||||
}
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_timeout_for("general-purpose") == 1800
|
||||
assert cfg.get_timeout_for("bash") == 60
|
||||
assert cfg.get_max_turns_for("general-purpose", 100) == 200
|
||||
assert cfg.get_max_turns_for("bash", 60) == 80
|
||||
|
||||
def test_load_partial_override(self):
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 600,
|
||||
"agents": {"bash": {"timeout_seconds": 120}},
|
||||
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 70}},
|
||||
}
|
||||
)
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.get_timeout_for("general-purpose") == 600
|
||||
assert cfg.get_timeout_for("bash") == 120
|
||||
assert cfg.get_max_turns_for("general-purpose", 100) == 100
|
||||
assert cfg.get_max_turns_for("bash", 60) == 70
|
||||
|
||||
def test_load_empty_dict_uses_defaults(self):
|
||||
load_subagents_config_from_dict({})
|
||||
cfg = get_subagents_app_config()
|
||||
assert cfg.timeout_seconds == 900
|
||||
assert cfg.max_turns is None
|
||||
assert cfg.agents == {}
|
||||
|
||||
def test_load_replaces_previous_config(self):
|
||||
load_subagents_config_from_dict({"timeout_seconds": 100})
|
||||
load_subagents_config_from_dict({"timeout_seconds": 100, "max_turns": 90})
|
||||
assert get_subagents_app_config().timeout_seconds == 100
|
||||
assert get_subagents_app_config().max_turns == 90
|
||||
|
||||
load_subagents_config_from_dict({"timeout_seconds": 200})
|
||||
load_subagents_config_from_dict({"timeout_seconds": 200, "max_turns": 110})
|
||||
assert get_subagents_app_config().timeout_seconds == 200
|
||||
assert get_subagents_app_config().max_turns == 110
|
||||
|
||||
def test_singleton_returns_same_instance_between_calls(self):
|
||||
load_subagents_config_from_dict({"timeout_seconds": 777})
|
||||
load_subagents_config_from_dict({"timeout_seconds": 777, "max_turns": 123})
|
||||
assert get_subagents_app_config() is get_subagents_app_config()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# registry.get_subagent_config – timeout override applied
|
||||
# registry.get_subagent_config – runtime overrides applied
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -211,25 +258,29 @@ class TestRegistryGetSubagentConfig:
|
||||
_reset_subagents_config(timeout_seconds=900)
|
||||
config = get_subagent_config("general-purpose")
|
||||
assert config.timeout_seconds == 900
|
||||
assert config.max_turns == 100
|
||||
|
||||
def test_global_timeout_override_applied(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
_reset_subagents_config(timeout_seconds=1800)
|
||||
_reset_subagents_config(timeout_seconds=1800, max_turns=140)
|
||||
config = get_subagent_config("general-purpose")
|
||||
assert config.timeout_seconds == 1800
|
||||
assert config.max_turns == 140
|
||||
|
||||
def test_per_agent_timeout_override_applied(self):
|
||||
def test_per_agent_runtime_override_applied(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"timeout_seconds": 120}},
|
||||
"max_turns": 120,
|
||||
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
|
||||
}
|
||||
)
|
||||
bash_config = get_subagent_config("bash")
|
||||
assert bash_config.timeout_seconds == 120
|
||||
assert bash_config.max_turns == 80
|
||||
|
||||
def test_per_agent_override_does_not_affect_other_agents(self):
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
@@ -237,11 +288,13 @@ class TestRegistryGetSubagentConfig:
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"agents": {"bash": {"timeout_seconds": 120}},
|
||||
"max_turns": 120,
|
||||
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
|
||||
}
|
||||
)
|
||||
gp_config = get_subagent_config("general-purpose")
|
||||
assert gp_config.timeout_seconds == 900
|
||||
assert gp_config.max_turns == 120
|
||||
|
||||
def test_builtin_config_object_is_not_mutated(self):
|
||||
"""Registry must return a new object, leaving the builtin default intact."""
|
||||
@@ -249,24 +302,27 @@ class TestRegistryGetSubagentConfig:
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
original_timeout = BUILTIN_SUBAGENTS["bash"].timeout_seconds
|
||||
load_subagents_config_from_dict({"timeout_seconds": 42})
|
||||
original_max_turns = BUILTIN_SUBAGENTS["bash"].max_turns
|
||||
load_subagents_config_from_dict({"timeout_seconds": 42, "max_turns": 88})
|
||||
|
||||
returned = get_subagent_config("bash")
|
||||
assert returned.timeout_seconds == 42
|
||||
assert returned.max_turns == 88
|
||||
assert BUILTIN_SUBAGENTS["bash"].timeout_seconds == original_timeout
|
||||
assert BUILTIN_SUBAGENTS["bash"].max_turns == original_max_turns
|
||||
|
||||
def test_config_preserves_other_fields(self):
|
||||
"""Applying timeout override must not change other SubagentConfig fields."""
|
||||
"""Applying runtime overrides must not change other SubagentConfig fields."""
|
||||
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
|
||||
from deerflow.subagents.registry import get_subagent_config
|
||||
|
||||
_reset_subagents_config(timeout_seconds=300)
|
||||
_reset_subagents_config(timeout_seconds=300, max_turns=140)
|
||||
original = BUILTIN_SUBAGENTS["general-purpose"]
|
||||
overridden = get_subagent_config("general-purpose")
|
||||
|
||||
assert overridden.name == original.name
|
||||
assert overridden.description == original.description
|
||||
assert overridden.max_turns == original.max_turns
|
||||
assert overridden.max_turns == 140
|
||||
assert overridden.model == original.model
|
||||
assert overridden.tools == original.tools
|
||||
assert overridden.disallowed_tools == original.disallowed_tools
|
||||
@@ -291,9 +347,10 @@ class TestRegistryListSubagents:
|
||||
def test_all_returned_configs_get_global_override(self):
|
||||
from deerflow.subagents.registry import list_subagents
|
||||
|
||||
_reset_subagents_config(timeout_seconds=123)
|
||||
_reset_subagents_config(timeout_seconds=123, max_turns=77)
|
||||
for cfg in list_subagents():
|
||||
assert cfg.timeout_seconds == 123, f"{cfg.name} has wrong timeout"
|
||||
assert cfg.max_turns == 77, f"{cfg.name} has wrong max_turns"
|
||||
|
||||
def test_per_agent_overrides_reflected_in_list(self):
|
||||
from deerflow.subagents.registry import list_subagents
|
||||
@@ -301,15 +358,18 @@ class TestRegistryListSubagents:
|
||||
load_subagents_config_from_dict(
|
||||
{
|
||||
"timeout_seconds": 900,
|
||||
"max_turns": 120,
|
||||
"agents": {
|
||||
"general-purpose": {"timeout_seconds": 1800},
|
||||
"bash": {"timeout_seconds": 60},
|
||||
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
|
||||
"bash": {"timeout_seconds": 60, "max_turns": 80},
|
||||
},
|
||||
}
|
||||
)
|
||||
by_name = {cfg.name: cfg for cfg in list_subagents()}
|
||||
assert by_name["general-purpose"].timeout_seconds == 1800
|
||||
assert by_name["bash"].timeout_seconds == 60
|
||||
assert by_name["general-purpose"].max_turns == 200
|
||||
assert by_name["bash"].max_turns == 80
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.gateway.routers import suggestions
|
||||
|
||||
@@ -43,7 +43,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
|
||||
model_name=None,
|
||||
)
|
||||
fake_model = MagicMock()
|
||||
fake_model.invoke.return_value = MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```')
|
||||
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```'))
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
result = asyncio.run(suggestions.generate_suggestions("t1", req))
|
||||
@@ -61,7 +61,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
|
||||
model_name=None,
|
||||
)
|
||||
fake_model = MagicMock()
|
||||
fake_model.invoke.return_value = MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}])
|
||||
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}]))
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
result = asyncio.run(suggestions.generate_suggestions("t1", req))
|
||||
@@ -79,7 +79,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
|
||||
model_name=None,
|
||||
)
|
||||
fake_model = MagicMock()
|
||||
fake_model.invoke.return_value = MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}])
|
||||
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}]))
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
result = asyncio.run(suggestions.generate_suggestions("t1", req))
|
||||
@@ -94,7 +94,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
|
||||
model_name=None,
|
||||
)
|
||||
fake_model = MagicMock()
|
||||
fake_model.invoke.side_effect = RuntimeError("boom")
|
||||
fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
result = asyncio.run(suggestions.generate_suggestions("t1", req))
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
"""Tests for thread_runs router with auth decorators.
|
||||
|
||||
These tests verify that auth decorators properly enforce permission checks
|
||||
on run endpoints. They follow the same pattern as test_threads_router.py.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.authz import AuthContext
|
||||
from app.gateway.routers.thread_runs import router
|
||||
|
||||
|
||||
def test_create_run_requires_auth():
|
||||
"""POST /{thread_id}/runs requires auth."""
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
response = client.post(
|
||||
"/api/threads/test-thread/runs",
|
||||
json={"assistant_id": "test"},
|
||||
)
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_create_run_with_auth():
|
||||
"""POST /{thread_id}/runs with valid auth passes through."""
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
mock_user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
mock_auth = AuthContext(
|
||||
user=mock_user,
|
||||
permissions=["runs:create", "threads:read", "threads:write"],
|
||||
)
|
||||
|
||||
# Mock the checkpointer and run_manager to avoid 503s
|
||||
mock_checkpointer = MagicMock()
|
||||
mock_run_manager = MagicMock()
|
||||
mock_run_manager.list_by_thread = MagicMock(return_value=[])
|
||||
mock_stream_bridge = MagicMock()
|
||||
|
||||
with patch("app.gateway.routers.thread_runs.get_checkpointer", return_value=mock_checkpointer):
|
||||
with patch("app.gateway.routers.thread_runs.get_run_manager", return_value=mock_run_manager):
|
||||
with patch("app.gateway.routers.thread_runs.get_stream_bridge", return_value=mock_stream_bridge):
|
||||
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
# Without a real checkpointer.setup, this will 500 - but the point is auth passed
|
||||
response = client.post(
|
||||
"/api/threads/test-thread/runs",
|
||||
json={"assistant_id": "test"},
|
||||
)
|
||||
# Auth passed if we don't get 401
|
||||
assert response.status_code != 401
|
||||
|
||||
|
||||
def test_list_runs_requires_auth():
|
||||
"""GET /{thread_id}/runs requires auth."""
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
response = client.get("/api/threads/test-thread/runs")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_list_runs_with_auth():
|
||||
"""GET /{thread_id}/runs with auth passes through."""
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
mock_user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||
mock_auth = AuthContext(
|
||||
user=mock_user,
|
||||
permissions=["runs:read", "threads:read"],
|
||||
)
|
||||
|
||||
mock_run_manager = MagicMock()
|
||||
mock_run_manager.list_by_thread = MagicMock(return_value=[])
|
||||
|
||||
with patch("app.gateway.routers.thread_runs.get_run_manager", return_value=mock_run_manager):
|
||||
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
response = client.get("/api/threads/test-thread/runs")
|
||||
# Should not be 401 (may be 500 or other, but auth passed)
|
||||
assert response.status_code != 401
|
||||
|
||||
|
||||
def test_get_run_requires_auth():
|
||||
"""GET /{thread_id}/runs/{run_id} requires auth."""
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
response = client.get("/api/threads/test-thread/runs/run-123")
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
def test_cancel_run_requires_auth():
|
||||
"""POST /{thread_id}/runs/{run_id}/cancel requires auth."""
|
||||
app = FastAPI()
|
||||
app.include_router(router)
|
||||
|
||||
with TestClient(app, raise_server_exceptions=False) as client:
|
||||
response = client.post("/api/threads/test-thread/runs/run-123/cancel")
|
||||
assert response.status_code == 401
|
||||
@@ -49,20 +49,43 @@ def test_delete_thread_data_rejects_invalid_thread_id(tmp_path):
|
||||
|
||||
|
||||
def test_delete_thread_route_cleans_thread_directory(tmp_path):
|
||||
"""DELETE /{thread_id} requires auth + permission — mock auth and store."""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.gateway.authz import AuthContext
|
||||
|
||||
tid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
|
||||
paths = Paths(tmp_path)
|
||||
thread_dir = paths.thread_dir("thread-route")
|
||||
paths.sandbox_work_dir("thread-route").mkdir(parents=True, exist_ok=True)
|
||||
(paths.sandbox_work_dir("thread-route") / "notes.txt").write_text("hello", encoding="utf-8")
|
||||
thread_dir = paths.thread_dir(tid)
|
||||
paths.sandbox_work_dir(tid).mkdir(parents=True, exist_ok=True)
|
||||
(paths.sandbox_work_dir(tid) / "notes.txt").write_text("hello", encoding="utf-8")
|
||||
|
||||
# Mock store item with .value attribute
|
||||
mock_store = MagicMock()
|
||||
mock_record = {
|
||||
"thread_id": tid,
|
||||
"metadata": {"user_id": "test-user-123"},
|
||||
}
|
||||
mock_store_item = MagicMock()
|
||||
mock_store_item.value = mock_record
|
||||
mock_store.aget = AsyncMock(return_value=mock_store_item)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = "test-user-123"
|
||||
mock_auth = AuthContext(user=mock_user, permissions=["threads:read", "threads:write", "threads:delete"])
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(threads.router)
|
||||
|
||||
with patch("app.gateway.routers.threads.get_paths", return_value=paths):
|
||||
with TestClient(app) as client:
|
||||
response = client.delete("/api/threads/thread-route")
|
||||
with patch("app.gateway.routers.threads.get_store", return_value=mock_store):
|
||||
with patch("app.gateway.routers.threads.get_checkpointer", return_value=MagicMock()):
|
||||
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
|
||||
with TestClient(app) as client:
|
||||
response = client.delete(f"/api/threads/{tid}")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"success": True, "message": "Deleted local thread data for thread-route"}
|
||||
assert response.json() == {"success": True, "message": f"Deleted local thread data for {tid}"}
|
||||
assert not thread_dir.exists()
|
||||
|
||||
|
||||
@@ -80,6 +103,7 @@ def test_delete_thread_route_rejects_invalid_thread_id(tmp_path):
|
||||
|
||||
|
||||
def test_delete_thread_route_returns_422_for_route_safe_invalid_id(tmp_path):
|
||||
"""DELETE /{thread_id} with non-UUID id — FastAPI rejects at path validation."""
|
||||
paths = Paths(tmp_path)
|
||||
|
||||
app = FastAPI()
|
||||
@@ -90,7 +114,9 @@ def test_delete_thread_route_returns_422_for_route_safe_invalid_id(tmp_path):
|
||||
response = client.delete("/api/threads/thread.with.dot")
|
||||
|
||||
assert response.status_code == 422
|
||||
assert "Invalid thread_id" in response.json()["detail"]
|
||||
# FastAPI returns a list of validation errors for path parameter mismatch
|
||||
detail = response.json()["detail"]
|
||||
assert any("thread_id" in str(err) for err in detail)
|
||||
|
||||
|
||||
def test_delete_thread_data_returns_generic_500_error(tmp_path):
|
||||
|
||||
@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.agents.middlewares import title_middleware as title_middleware_module
|
||||
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
|
||||
|
||||
@@ -73,37 +74,32 @@ class TestTitleMiddlewareCoreLogic:
|
||||
|
||||
assert middleware._should_generate_title(state) is False
|
||||
|
||||
def test_generate_title_trims_quotes_and_respects_max_chars(self, monkeypatch):
|
||||
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=12)
|
||||
middleware = TitleMiddleware()
|
||||
fake_model = MagicMock()
|
||||
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='"A very long generated title"'))
|
||||
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="请帮我写一个脚本"),
|
||||
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
|
||||
AIMessage(content="好的,先确认需求"),
|
||||
]
|
||||
}
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
title = result["title"]
|
||||
|
||||
assert '"' not in title
|
||||
assert "'" not in title
|
||||
assert len(title) == 12
|
||||
assert title == "短标题"
|
||||
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
|
||||
model.ainvoke.assert_awaited_once()
|
||||
|
||||
def test_generate_title_normalizes_structured_message_and_response_content(self, monkeypatch):
|
||||
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=20)
|
||||
middleware = TitleMiddleware()
|
||||
fake_model = MagicMock()
|
||||
fake_model.ainvoke = AsyncMock(
|
||||
return_value=MagicMock(content=[{"type": "text", "text": '"结构总结"'}]),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"deerflow.agents.middlewares.title_middleware.create_chat_model",
|
||||
lambda **kwargs: fake_model,
|
||||
)
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
|
||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
@@ -115,21 +111,14 @@ class TestTitleMiddlewareCoreLogic:
|
||||
result = asyncio.run(middleware._agenerate_title_result(state))
|
||||
title = result["title"]
|
||||
|
||||
prompt = fake_model.ainvoke.await_args.args[0]
|
||||
assert "请帮我总结这段代码" in prompt
|
||||
assert "好的,先看结构" in prompt
|
||||
# Ensure structured message dict/JSON reprs are not leaking into the prompt.
|
||||
assert "{'type':" not in prompt
|
||||
assert "'type':" not in prompt
|
||||
assert '"type":' not in prompt
|
||||
assert title == "结构总结"
|
||||
assert title == "请帮我总结这段代码"
|
||||
|
||||
def test_generate_title_fallback_when_model_fails(self, monkeypatch):
|
||||
def test_generate_title_fallback_for_long_message(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=20)
|
||||
middleware = TitleMiddleware()
|
||||
fake_model = MagicMock()
|
||||
fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
|
||||
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
|
||||
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
@@ -164,13 +153,10 @@ class TestTitleMiddlewareCoreLogic:
|
||||
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
|
||||
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
|
||||
|
||||
def test_sync_generate_title_with_model(self, monkeypatch):
|
||||
"""Sync path calls model.invoke and produces a title."""
|
||||
def test_sync_generate_title_uses_fallback_without_model(self):
|
||||
"""Sync path avoids LLM calls and derives a local fallback title."""
|
||||
_set_test_title_config(max_chars=20)
|
||||
middleware = TitleMiddleware()
|
||||
fake_model = MagicMock()
|
||||
fake_model.invoke = MagicMock(return_value=MagicMock(content='"同步生成的标题"'))
|
||||
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
@@ -179,22 +165,19 @@ class TestTitleMiddlewareCoreLogic:
|
||||
]
|
||||
}
|
||||
result = middleware._generate_title_result(state)
|
||||
assert result == {"title": "同步生成的标题"}
|
||||
fake_model.invoke.assert_called_once()
|
||||
assert result == {"title": "请帮我写测试"}
|
||||
|
||||
def test_empty_title_falls_back(self, monkeypatch):
|
||||
"""Empty model response triggers fallback title."""
|
||||
def test_sync_generate_title_respects_fallback_truncation(self):
|
||||
"""Sync fallback path still respects max_chars truncation rules."""
|
||||
_set_test_title_config(max_chars=50)
|
||||
middleware = TitleMiddleware()
|
||||
fake_model = MagicMock()
|
||||
fake_model.invoke = MagicMock(return_value=MagicMock(content=" "))
|
||||
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
|
||||
|
||||
state = {
|
||||
"messages": [
|
||||
HumanMessage(content="空标题测试"),
|
||||
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题,而且这里继续补充更多上下文,确保超过本地fallback截断阈值"),
|
||||
AIMessage(content="回复"),
|
||||
]
|
||||
}
|
||||
result = middleware._generate_title_result(state)
|
||||
assert result["title"] == "空标题测试"
|
||||
assert result["title"].endswith("...")
|
||||
assert result["title"].startswith("这是一个非常长的问题描述")
|
||||
|
||||
@@ -289,6 +289,8 @@ class TestBeforeAgent:
|
||||
"size": 5,
|
||||
"path": "/mnt/user-data/uploads/notes.txt",
|
||||
"extension": ".txt",
|
||||
"outline": [],
|
||||
"outline_preview": [],
|
||||
}
|
||||
]
|
||||
|
||||
@@ -339,3 +341,130 @@ class TestBeforeAgent:
|
||||
result = mw.before_agent(self._state(msg), _runtime())
|
||||
|
||||
assert result["messages"][-1].id == "original-id-42"
|
||||
|
||||
def test_outline_injected_when_md_file_exists(self, tmp_path):
|
||||
"""When a converted .md file exists alongside the upload, its outline is injected."""
|
||||
mw = _middleware(tmp_path)
|
||||
uploads_dir = _uploads_dir(tmp_path)
|
||||
(uploads_dir / "report.pdf").write_bytes(b"%PDF fake")
|
||||
# Simulate the .md produced by the conversion pipeline
|
||||
(uploads_dir / "report.md").write_text(
|
||||
"# PART I\n\n## ITEM 1. BUSINESS\n\nBody text.\n\n## ITEM 2. RISK\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
msg = _human("summarise", files=[{"filename": "report.pdf", "size": 9, "path": "/mnt/user-data/uploads/report.pdf"}])
|
||||
result = mw.before_agent(self._state(msg), _runtime())
|
||||
|
||||
assert result is not None
|
||||
content = result["messages"][-1].content
|
||||
assert "Document outline" in content
|
||||
assert "PART I" in content
|
||||
assert "ITEM 1. BUSINESS" in content
|
||||
assert "ITEM 2. RISK" in content
|
||||
assert "read_file" in content
|
||||
|
||||
def test_no_outline_when_no_md_file(self, tmp_path):
|
||||
"""Files without a sibling .md have no outline section."""
|
||||
mw = _middleware(tmp_path)
|
||||
uploads_dir = _uploads_dir(tmp_path)
|
||||
(uploads_dir / "data.xlsx").write_bytes(b"fake-xlsx")
|
||||
|
||||
msg = _human("analyse", files=[{"filename": "data.xlsx", "size": 9, "path": "/mnt/user-data/uploads/data.xlsx"}])
|
||||
result = mw.before_agent(self._state(msg), _runtime())
|
||||
|
||||
assert result is not None
|
||||
content = result["messages"][-1].content
|
||||
assert "Document outline" not in content
|
||||
|
||||
def test_outline_truncation_hint_shown(self, tmp_path):
|
||||
"""When outline is truncated, a hint line is appended after the last visible entry."""
|
||||
from deerflow.utils.file_conversion import MAX_OUTLINE_ENTRIES
|
||||
|
||||
mw = _middleware(tmp_path)
|
||||
uploads_dir = _uploads_dir(tmp_path)
|
||||
(uploads_dir / "big.pdf").write_bytes(b"%PDF fake")
|
||||
# Write MAX_OUTLINE_ENTRIES + 5 headings so truncation is triggered
|
||||
headings = "\n".join(f"# Heading {i}" for i in range(MAX_OUTLINE_ENTRIES + 5))
|
||||
(uploads_dir / "big.md").write_text(headings, encoding="utf-8")
|
||||
|
||||
msg = _human("read", files=[{"filename": "big.pdf", "size": 9, "path": "/mnt/user-data/uploads/big.pdf"}])
|
||||
result = mw.before_agent(self._state(msg), _runtime())
|
||||
|
||||
assert result is not None
|
||||
content = result["messages"][-1].content
|
||||
assert f"showing first {MAX_OUTLINE_ENTRIES} headings" in content
|
||||
assert "use `read_file` to explore further" in content
|
||||
|
||||
def test_no_truncation_hint_for_short_outline(self, tmp_path):
|
||||
"""Short outlines (under the cap) must not show a truncation hint."""
|
||||
mw = _middleware(tmp_path)
|
||||
uploads_dir = _uploads_dir(tmp_path)
|
||||
(uploads_dir / "short.pdf").write_bytes(b"%PDF fake")
|
||||
(uploads_dir / "short.md").write_text("# Intro\n\n# Conclusion\n", encoding="utf-8")
|
||||
|
||||
msg = _human("read", files=[{"filename": "short.pdf", "size": 9, "path": "/mnt/user-data/uploads/short.pdf"}])
|
||||
result = mw.before_agent(self._state(msg), _runtime())
|
||||
|
||||
assert result is not None
|
||||
content = result["messages"][-1].content
|
||||
assert "showing first" not in content
|
||||
|
||||
def test_historical_file_outline_injected(self, tmp_path):
|
||||
"""Outline is also shown for historical (previously uploaded) files."""
|
||||
mw = _middleware(tmp_path)
|
||||
uploads_dir = _uploads_dir(tmp_path)
|
||||
# Historical file with .md
|
||||
(uploads_dir / "old_report.pdf").write_bytes(b"%PDF old")
|
||||
(uploads_dir / "old_report.md").write_text(
|
||||
"# Chapter 1\n\n# Chapter 2\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
# New file without .md
|
||||
(uploads_dir / "new.txt").write_bytes(b"new")
|
||||
|
||||
msg = _human("go", files=[{"filename": "new.txt", "size": 3, "path": "/mnt/user-data/uploads/new.txt"}])
|
||||
result = mw.before_agent(self._state(msg), _runtime())
|
||||
|
||||
assert result is not None
|
||||
content = result["messages"][-1].content
|
||||
assert "Chapter 1" in content
|
||||
assert "Chapter 2" in content
|
||||
|
||||
def test_fallback_preview_shown_when_outline_empty(self, tmp_path):
|
||||
"""When .md exists but has no headings, first lines are shown as a preview."""
|
||||
mw = _middleware(tmp_path)
|
||||
uploads_dir = _uploads_dir(tmp_path)
|
||||
(uploads_dir / "report.pdf").write_bytes(b"%PDF fake")
|
||||
# .md with no # headings — plain prose only
|
||||
(uploads_dir / "report.md").write_text(
|
||||
"Annual Financial Report 2024\n\nThis document summarises key findings.\n\nRevenue grew by 12%.\n",
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
msg = _human("analyse", files=[{"filename": "report.pdf", "size": 9, "path": "/mnt/user-data/uploads/report.pdf"}])
|
||||
result = mw.before_agent(self._state(msg), _runtime())
|
||||
|
||||
assert result is not None
|
||||
content = result["messages"][-1].content
|
||||
# Outline section must NOT appear
|
||||
assert "Document outline" not in content
|
||||
# Preview lines must appear
|
||||
assert "Annual Financial Report 2024" in content
|
||||
assert "No structural headings detected" in content
|
||||
# grep hint must appear
|
||||
assert "grep" in content
|
||||
|
||||
def test_fallback_grep_hint_shown_when_no_md_file(self, tmp_path):
|
||||
"""Files with no sibling .md still get the grep hint (outline is empty)."""
|
||||
mw = _middleware(tmp_path)
|
||||
uploads_dir = _uploads_dir(tmp_path)
|
||||
(uploads_dir / "data.csv").write_bytes(b"a,b,c\n1,2,3\n")
|
||||
|
||||
msg = _human("analyse", files=[{"filename": "data.csv", "size": 12, "path": "/mnt/user-data/uploads/data.csv"}])
|
||||
result = mw.before_agent(self._state(msg), _runtime())
|
||||
|
||||
assert result is not None
|
||||
content = result["messages"][-1].content
|
||||
assert "Document outline" not in content
|
||||
assert "grep" in content
|
||||
|
||||
Reference in New Issue
Block a user