feat(auth): authentication module with multi-tenant isolation (RFC-001)

Introduce an always-on auth layer with auto-created admin on first boot,
multi-tenant isolation for threads/stores, and a full setup/login flow.

Backend
- JWT access tokens with `ver` field for stale-token rejection; bump on
  password/email change
- Password hashing, HttpOnly+Secure cookies (Secure derived from request
  scheme at runtime)
- CSRF middleware covering both REST and LangGraph routes
- IP-based login rate limiting (5 attempts / 5-min lockout) with bounded
  dict growth and X-Forwarded-For bypass fix
- Multi-worker-safe admin auto-creation (single DB write, WAL once)
- needs_setup + token_version on User model; SQLite schema migration
- Thread/store isolation by owner; orphan thread migration on first admin
  registration
- thread_id validated as UUID to prevent log injection
- CLI tool to reset admin password
- Decorator-based authz module extracted from auth core

Frontend
- Login and setup pages with SSR guard for needs_setup flow
- Account settings page (change password / email)
- AuthProvider + route guards; skips redirect when no users registered
- i18n (en-US / zh-CN) for auth surfaces
- Typed auth API client; parseAuthError unwraps FastAPI detail envelope

Infra & tooling
- Unified `serve.sh` with gateway mode + auto dep install
- Public PyPI uv.toml pin for CI compatibility
- Regenerated uv.lock with public index

Tests
- HTTP vs HTTPS cookie security tests
- Auth middleware, rate limiter, CSRF, setup flow coverage
This commit is contained in:
greatmengqi
2026-04-08 00:31:43 +08:00
parent 636053fb6d
commit 27b66d6753
214 changed files with 18830 additions and 1065 deletions
+506
View File
@@ -0,0 +1,506 @@
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
from datetime import timedelta
from unittest.mock import AsyncMock, MagicMock, patch
from uuid import uuid4
import pytest
from fastapi import FastAPI, HTTPException
from fastapi.testclient import TestClient
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
from app.gateway.auth.models import User
from app.gateway.authz import (
AuthContext,
Permissions,
get_auth_context,
require_auth,
require_permission,
)
# ── Password Hashing ────────────────────────────────────────────────────────
def test_hash_password_and_verify():
"""Hashing and verification round-trip."""
password = "s3cr3tP@ssw0rd!"
hashed = hash_password(password)
assert hashed != password
assert verify_password(password, hashed) is True
assert verify_password("wrongpassword", hashed) is False
def test_hash_password_different_each_time():
"""bcrypt generates unique salts, so same password has different hashes."""
password = "testpassword"
h1 = hash_password(password)
h2 = hash_password(password)
assert h1 != h2 # Different salts
# But both verify correctly
assert verify_password(password, h1) is True
assert verify_password(password, h2) is True
def test_verify_password_rejects_empty():
"""Empty password should not verify."""
hashed = hash_password("nonempty")
assert verify_password("", hashed) is False
# ── JWT ─────────────────────────────────────────────────────────────────────
def test_create_and_decode_token():
"""JWT creation and decoding round-trip."""
user_id = str(uuid4())
# Set a valid JWT secret for this test
import os
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(user_id)
assert isinstance(token, str)
payload = decode_token(token)
assert payload is not None
assert payload.sub == user_id
def test_decode_token_expired():
"""Expired token returns TokenError.EXPIRED."""
from app.gateway.auth.errors import TokenError
user_id = str(uuid4())
# Create token that expires immediately
token = create_access_token(user_id, expires_delta=timedelta(seconds=-1))
payload = decode_token(token)
assert payload == TokenError.EXPIRED
def test_decode_token_invalid():
"""Invalid token returns TokenError."""
from app.gateway.auth.errors import TokenError
assert isinstance(decode_token("not.a.valid.token"), TokenError)
assert isinstance(decode_token(""), TokenError)
assert isinstance(decode_token("completely-wrong"), TokenError)
def test_create_token_custom_expiry():
"""Custom expiry is respected."""
user_id = str(uuid4())
token = create_access_token(user_id, expires_delta=timedelta(hours=1))
payload = decode_token(token)
assert payload is not None
assert payload.sub == user_id
# ── AuthContext ────────────────────────────────────────────────────────────
def test_auth_context_unauthenticated():
"""AuthContext with no user."""
ctx = AuthContext(user=None, permissions=[])
assert ctx.is_authenticated is False
assert ctx.has_permission("threads", "read") is False
def test_auth_context_authenticated_no_perms():
"""AuthContext with user but no permissions."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
ctx = AuthContext(user=user, permissions=[])
assert ctx.is_authenticated is True
assert ctx.has_permission("threads", "read") is False
def test_auth_context_has_permission():
"""AuthContext permission checking."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
ctx = AuthContext(user=user, permissions=perms)
assert ctx.has_permission("threads", "read") is True
assert ctx.has_permission("threads", "write") is True
assert ctx.has_permission("threads", "delete") is False
assert ctx.has_permission("runs", "read") is False
def test_auth_context_require_user_raises():
"""require_user raises 401 when not authenticated."""
ctx = AuthContext(user=None, permissions=[])
with pytest.raises(HTTPException) as exc_info:
ctx.require_user()
assert exc_info.value.status_code == 401
def test_auth_context_require_user_returns_user():
"""require_user returns user when authenticated."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
ctx = AuthContext(user=user, permissions=[])
returned = ctx.require_user()
assert returned == user
# ── get_auth_context helper ─────────────────────────────────────────────────
def test_get_auth_context_not_set():
"""get_auth_context returns None when auth not set on request."""
mock_request = MagicMock()
# Make getattr return None (simulating attribute not set)
mock_request.state = MagicMock()
del mock_request.state.auth
assert get_auth_context(mock_request) is None
def test_get_auth_context_set():
"""get_auth_context returns the AuthContext from request."""
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
ctx = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
mock_request = MagicMock()
mock_request.state.auth = ctx
assert get_auth_context(mock_request) == ctx
# ── require_auth decorator ──────────────────────────────────────────────────
def test_require_auth_sets_auth_context():
"""require_auth sets auth context on request from cookie."""
from fastapi import Request
app = FastAPI()
@app.get("/test")
@require_auth
async def endpoint(request: Request):
ctx = get_auth_context(request)
return {"authenticated": ctx.is_authenticated}
with TestClient(app) as client:
# No cookie → anonymous
response = client.get("/test")
assert response.status_code == 200
assert response.json()["authenticated"] is False
def test_require_auth_requires_request_param():
"""require_auth raises ValueError if request parameter is missing."""
import asyncio
@require_auth
async def bad_endpoint(): # Missing `request` parameter
pass
with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
asyncio.run(bad_endpoint())
# ── require_permission decorator ─────────────────────────────────────────────
def test_require_permission_requires_auth():
"""require_permission raises 401 when not authenticated."""
from fastapi import Request
app = FastAPI()
@app.get("/test")
@require_permission("threads", "read")
async def endpoint(request: Request):
return {"ok": True}
with TestClient(app) as client:
response = client.get("/test")
assert response.status_code == 401
assert "Authentication required" in response.json()["detail"]
def test_require_permission_denies_wrong_permission():
"""User without required permission gets 403."""
from fastapi import Request
app = FastAPI()
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
@app.get("/test")
@require_permission("threads", "delete")
async def endpoint(request: Request):
return {"ok": True}
mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
with TestClient(app) as client:
response = client.get("/test")
assert response.status_code == 403
assert "Permission denied" in response.json()["detail"]
# ── Weak JWT secret warning ──────────────────────────────────────────────────
# ── User Model Fields ──────────────────────────────────────────────────────
def test_user_model_has_needs_setup_default_false():
"""New users default to needs_setup=False."""
user = User(email="test@example.com", password_hash="hash")
assert user.needs_setup is False
def test_user_model_has_token_version_default_zero():
"""New users default to token_version=0."""
user = User(email="test@example.com", password_hash="hash")
assert user.token_version == 0
def test_user_model_needs_setup_true():
"""Auto-created admin has needs_setup=True."""
user = User(email="admin@example.com", password_hash="hash", needs_setup=True)
assert user.needs_setup is True
def test_sqlite_round_trip_new_fields():
"""needs_setup and token_version survive create → read round-trip."""
import asyncio
import os
import tempfile
from pathlib import Path
from app.gateway.auth.repositories import sqlite as sqlite_mod
with tempfile.TemporaryDirectory() as tmpdir:
db_path = os.path.join(tmpdir, "test_users.db")
old_path = sqlite_mod._resolved_db_path
old_init = sqlite_mod._table_initialized
sqlite_mod._resolved_db_path = Path(db_path)
sqlite_mod._table_initialized = False
try:
repo = sqlite_mod.SQLiteUserRepository()
user = User(
email="setup@test.com",
password_hash="fakehash",
system_role="admin",
needs_setup=True,
token_version=3,
)
created = asyncio.run(repo.create_user(user))
assert created.needs_setup is True
assert created.token_version == 3
fetched = asyncio.run(repo.get_user_by_email("setup@test.com"))
assert fetched is not None
assert fetched.needs_setup is True
assert fetched.token_version == 3
fetched.needs_setup = False
fetched.token_version = 4
asyncio.run(repo.update_user(fetched))
refetched = asyncio.run(repo.get_user_by_id(str(fetched.id)))
assert refetched.needs_setup is False
assert refetched.token_version == 4
finally:
sqlite_mod._resolved_db_path = old_path
sqlite_mod._table_initialized = old_init
# ── Token Versioning ───────────────────────────────────────────────────────
def test_jwt_encodes_ver():
"""JWT payload includes ver field."""
import os
from app.gateway.auth.errors import TokenError
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(str(uuid4()), token_version=3)
payload = decode_token(token)
assert not isinstance(payload, TokenError)
assert payload.ver == 3
def test_jwt_default_ver_zero():
"""JWT ver defaults to 0."""
import os
from app.gateway.auth.errors import TokenError
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
token = create_access_token(str(uuid4()))
payload = decode_token(token)
assert not isinstance(payload, TokenError)
assert payload.ver == 0
def test_token_version_mismatch_rejects():
"""Token with stale ver is rejected by get_current_user_from_request."""
import asyncio
import os
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
user_id = str(uuid4())
token = create_access_token(user_id, token_version=0)
mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
mock_request = MagicMock()
mock_request.cookies = {"access_token": token}
with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
mock_provider = MagicMock()
mock_provider.get_user = AsyncMock(return_value=mock_user)
mock_provider_fn.return_value = mock_provider
from app.gateway.deps import get_current_user_from_request
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
assert "revoked" in str(exc_info.value.detail).lower()
# ── change-password extension ──────────────────────────────────────────────
def test_change_password_request_accepts_new_email():
"""ChangePasswordRequest model accepts optional new_email."""
from app.gateway.routers.auth import ChangePasswordRequest
req = ChangePasswordRequest(
current_password="old",
new_password="newpassword",
new_email="new@example.com",
)
assert req.new_email == "new@example.com"
def test_change_password_request_new_email_optional():
"""ChangePasswordRequest model works without new_email."""
from app.gateway.routers.auth import ChangePasswordRequest
req = ChangePasswordRequest(current_password="old", new_password="newpassword")
assert req.new_email is None
def test_login_response_includes_needs_setup():
"""LoginResponse includes needs_setup field."""
from app.gateway.routers.auth import LoginResponse
resp = LoginResponse(expires_in=3600, needs_setup=True)
assert resp.needs_setup is True
resp2 = LoginResponse(expires_in=3600)
assert resp2.needs_setup is False
# ── Rate Limiting ──────────────────────────────────────────────────────────
def test_rate_limiter_allows_under_limit():
"""Requests under the limit are allowed."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts
_login_attempts.clear()
_check_rate_limit("192.168.1.1") # Should not raise
def test_rate_limiter_blocks_after_max_failures():
"""IP is blocked after 5 consecutive failures."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure
_login_attempts.clear()
ip = "10.0.0.1"
for _ in range(5):
_record_login_failure(ip)
with pytest.raises(HTTPException) as exc_info:
_check_rate_limit(ip)
assert exc_info.value.status_code == 429
def test_rate_limiter_resets_on_success():
"""Successful login clears the failure counter."""
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
_login_attempts.clear()
ip = "10.0.0.2"
for _ in range(4):
_record_login_failure(ip)
_record_login_success(ip)
_check_rate_limit(ip) # Should not raise
# ── Client IP extraction ─────────────────────────────────────────────────
def test_get_client_ip_direct_connection():
"""Without nginx (no X-Real-IP), falls back to request.client.host."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "203.0.113.42"
req.headers = {}
assert _get_client_ip(req) == "203.0.113.42"
def test_get_client_ip_uses_x_real_ip():
"""X-Real-IP (set by nginx) is used when present."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "10.0.0.1" # uvicorn may have replaced this with XFF[0]
req.headers = {"x-real-ip": "203.0.113.42"}
assert _get_client_ip(req) == "203.0.113.42"
def test_get_client_ip_xff_ignored():
"""X-Forwarded-For is never used; only X-Real-IP matters."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "10.0.0.1"
req.headers = {"x-forwarded-for": "10.0.0.1, 198.51.100.5", "x-real-ip": "198.51.100.5"}
assert _get_client_ip(req) == "198.51.100.5"
def test_get_client_ip_no_real_ip_fallback():
"""No X-Real-IP → falls back to client.host (direct connection)."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "127.0.0.1"
req.headers = {}
assert _get_client_ip(req) == "127.0.0.1"
def test_get_client_ip_x_real_ip_always_preferred():
"""X-Real-IP is always preferred over client.host regardless of IP."""
from app.gateway.routers.auth import _get_client_ip
req = MagicMock()
req.client.host = "203.0.113.99"
req.headers = {"x-real-ip": "198.51.100.7"}
assert _get_client_ip(req) == "198.51.100.7"
# ── Weak JWT secret warning ──────────────────────────────────────────────────
def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
"""get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
import logging
import app.gateway.auth.config as config_module
config_module._auth_config = None
monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)
with caplog.at_level(logging.WARNING):
config = config_module.get_auth_config()
assert config.jwt_secret # non-empty ephemeral secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
# Cleanup
config_module._auth_config = None
+54
View File
@@ -0,0 +1,54 @@
"""Tests for AuthConfig typed configuration."""
import os
from unittest.mock import patch
import pytest
from app.gateway.auth.config import AuthConfig
def test_auth_config_defaults():
config = AuthConfig(jwt_secret="test-secret-key-123")
assert config.token_expiry_days == 7
def test_auth_config_token_expiry_range():
AuthConfig(jwt_secret="s", token_expiry_days=1)
AuthConfig(jwt_secret="s", token_expiry_days=30)
with pytest.raises(Exception):
AuthConfig(jwt_secret="s", token_expiry_days=0)
with pytest.raises(Exception):
AuthConfig(jwt_secret="s", token_expiry_days=31)
def test_auth_config_from_env():
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
with patch.dict(os.environ, env, clear=False):
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
config = cfg.get_auth_config()
assert config.jwt_secret == "test-jwt-secret-from-env"
finally:
cfg._auth_config = old
def test_auth_config_missing_secret_generates_ephemeral(caplog):
import logging
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with caplog.at_level(logging.WARNING):
config = cfg.get_auth_config()
assert config.jwt_secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
finally:
cfg._auth_config = old
+75
View File
@@ -0,0 +1,75 @@
"""Tests for auth error types and typed decode_token."""
from datetime import UTC, datetime, timedelta
import jwt as pyjwt
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import create_access_token, decode_token
def test_auth_error_code_values():
assert AuthErrorCode.INVALID_CREDENTIALS == "invalid_credentials"
assert AuthErrorCode.TOKEN_EXPIRED == "token_expired"
assert AuthErrorCode.NOT_AUTHENTICATED == "not_authenticated"
def test_token_error_values():
assert TokenError.EXPIRED == "expired"
assert TokenError.INVALID_SIGNATURE == "invalid_signature"
assert TokenError.MALFORMED == "malformed"
def test_auth_error_response_serialization():
err = AuthErrorResponse(
code=AuthErrorCode.TOKEN_EXPIRED,
message="Token has expired",
)
d = err.model_dump()
assert d == {"code": "token_expired", "message": "Token has expired"}
def test_auth_error_response_from_dict():
d = {"code": "invalid_credentials", "message": "Wrong password"}
err = AuthErrorResponse(**d)
assert err.code == AuthErrorCode.INVALID_CREDENTIALS
# ── decode_token typed failure tests ──────────────────────────────
_TEST_SECRET = "test-secret-for-jwt-decode-token-tests"
def _setup_config():
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
def test_decode_token_returns_token_error_on_expired():
_setup_config()
expired_payload = {"sub": "user-1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired_payload, _TEST_SECRET, algorithm="HS256")
result = decode_token(token)
assert result == TokenError.EXPIRED
def test_decode_token_returns_token_error_on_bad_signature():
_setup_config()
payload = {"sub": "user-1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
result = decode_token(token)
assert result == TokenError.INVALID_SIGNATURE
def test_decode_token_returns_token_error_on_malformed():
_setup_config()
result = decode_token("not-a-jwt")
assert result == TokenError.MALFORMED
def test_decode_token_returns_payload_on_valid():
_setup_config()
token = create_access_token("user-123")
result = decode_token(token)
assert not isinstance(result, TokenError)
assert result.sub == "user-123"
+216
View File
@@ -0,0 +1,216 @@
"""Tests for the global AuthMiddleware (fail-closed safety net)."""
import pytest
from starlette.testclient import TestClient
from app.gateway.auth_middleware import AuthMiddleware, _is_public
# ── _is_public unit tests ─────────────────────────────────────────────────
@pytest.mark.parametrize(
"path",
[
"/health",
"/health/",
"/docs",
"/docs/",
"/redoc",
"/openapi.json",
"/api/v1/auth/login/local",
"/api/v1/auth/register",
"/api/v1/auth/logout",
"/api/v1/auth/setup-status",
],
)
def test_public_paths(path: str):
assert _is_public(path) is True
@pytest.mark.parametrize(
"path",
[
"/api/models",
"/api/mcp/config",
"/api/memory",
"/api/skills",
"/api/threads/123",
"/api/threads/123/uploads",
"/api/agents",
"/api/channels",
"/api/runs/stream",
"/api/threads/123/runs",
"/api/v1/auth/me",
"/api/v1/auth/change-password",
],
)
def test_protected_paths(path: str):
assert _is_public(path) is False
# ── Trailing slash / normalization edge cases ─────────────────────────────
@pytest.mark.parametrize(
"path",
[
"/api/v1/auth/login/local/",
"/api/v1/auth/register/",
"/api/v1/auth/logout/",
"/api/v1/auth/setup-status/",
],
)
def test_public_auth_paths_with_trailing_slash(path: str):
assert _is_public(path) is True
@pytest.mark.parametrize(
"path",
[
"/api/models/",
"/api/v1/auth/me/",
"/api/v1/auth/change-password/",
],
)
def test_protected_paths_with_trailing_slash(path: str):
assert _is_public(path) is False
def test_unknown_api_path_is_protected():
"""Fail-closed: any new /api/* path is protected by default."""
assert _is_public("/api/new-feature") is False
assert _is_public("/api/v2/something") is False
assert _is_public("/api/v1/auth/new-endpoint") is False
# ── Middleware integration tests ──────────────────────────────────────────
def _make_app():
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
from fastapi import FastAPI
app = FastAPI()
app.add_middleware(AuthMiddleware)
@app.get("/health")
async def health():
return {"status": "ok"}
@app.get("/api/v1/auth/me")
async def auth_me():
return {"id": "1", "email": "test@test.com"}
@app.get("/api/v1/auth/setup-status")
async def setup_status():
return {"needs_setup": False}
@app.get("/api/models")
async def models_get():
return {"models": []}
@app.put("/api/mcp/config")
async def mcp_put():
return {"ok": True}
@app.delete("/api/threads/abc")
async def thread_delete():
return {"ok": True}
@app.patch("/api/threads/abc")
async def thread_patch():
return {"ok": True}
@app.post("/api/threads/abc/runs/stream")
async def stream():
return {"ok": True}
@app.get("/api/future-endpoint")
async def future():
return {"ok": True}
return app
@pytest.fixture
def client():
return TestClient(_make_app())
def test_public_path_no_cookie(client):
res = client.get("/health")
assert res.status_code == 200
def test_public_auth_path_no_cookie(client):
"""Public auth endpoints (login/register) pass without cookie."""
res = client.get("/api/v1/auth/setup-status")
assert res.status_code == 200
def test_protected_auth_path_no_cookie(client):
"""/auth/me requires cookie even though it's under /api/v1/auth/."""
res = client.get("/api/v1/auth/me")
assert res.status_code == 401
def test_protected_path_no_cookie_returns_401(client):
res = client.get("/api/models")
assert res.status_code == 401
body = res.json()
assert body["detail"]["code"] == "not_authenticated"
def test_protected_path_with_cookie_passes(client):
res = client.get("/api/models", cookies={"access_token": "some-token"})
assert res.status_code == 200
def test_protected_post_no_cookie_returns_401(client):
res = client.post("/api/threads/abc/runs/stream")
assert res.status_code == 401
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
def test_protected_put_no_cookie(client):
res = client.put("/api/mcp/config")
assert res.status_code == 401
def test_protected_delete_no_cookie(client):
res = client.delete("/api/threads/abc")
assert res.status_code == 401
def test_protected_patch_no_cookie(client):
res = client.patch("/api/threads/abc")
assert res.status_code == 401
def test_put_with_cookie_passes(client):
client.cookies.set("access_token", "tok")
res = client.put("/api/mcp/config")
assert res.status_code == 200
def test_delete_with_cookie_passes(client):
client.cookies.set("access_token", "tok")
res = client.delete("/api/threads/abc")
assert res.status_code == 200
# ── Fail-closed: unknown future endpoints ─────────────────────────────────
def test_unknown_endpoint_no_cookie_returns_401(client):
"""Any new /api/* endpoint is blocked by default without cookie."""
res = client.get("/api/future-endpoint")
assert res.status_code == 401
def test_unknown_endpoint_with_cookie_passes(client):
client.cookies.set("access_token", "tok")
res = client.get("/api/future-endpoint")
assert res.status_code == 200
+675
View File
@@ -0,0 +1,675 @@
"""Tests for auth type system hardening.
Covers structured error responses, typed decode_token callers,
CSRF middleware path matching, config-driven cookie security,
and unhappy paths / edge cases for all auth boundaries.
"""
import os
import secrets
from datetime import UTC, datetime, timedelta
from unittest.mock import patch
import jwt as pyjwt
import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from pydantic import ValidationError
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
from app.gateway.auth.jwt import decode_token
from app.gateway.csrf_middleware import (
CSRF_COOKIE_NAME,
CSRF_HEADER_NAME,
CSRFMiddleware,
is_auth_endpoint,
should_check_csrf,
)
# ── Setup ────────────────────────────────────────────────────────────
_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
def _setup_config():
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
# ── CSRF Middleware Path Matching ────────────────────────────────────
class _FakeRequest:
"""Minimal request mock for CSRF path matching tests."""
def __init__(self, path: str, method: str = "POST"):
self.method = method
class _URL:
def __init__(self, p):
self.path = p
self.url = _URL(path)
self.cookies = {}
self.headers = {}
def test_csrf_exempts_login_local():
"""login/local (actual route) should be exempt from CSRF."""
req = _FakeRequest("/api/v1/auth/login/local")
assert is_auth_endpoint(req) is True
def test_csrf_exempts_login_local_trailing_slash():
"""Trailing slash should also be exempt."""
req = _FakeRequest("/api/v1/auth/login/local/")
assert is_auth_endpoint(req) is True
def test_csrf_exempts_logout():
req = _FakeRequest("/api/v1/auth/logout")
assert is_auth_endpoint(req) is True
def test_csrf_exempts_register():
req = _FakeRequest("/api/v1/auth/register")
assert is_auth_endpoint(req) is True
def test_csrf_does_not_exempt_old_login_path():
"""Old /api/v1/auth/login (without /local) should NOT be exempt."""
req = _FakeRequest("/api/v1/auth/login")
assert is_auth_endpoint(req) is False
def test_csrf_does_not_exempt_me():
req = _FakeRequest("/api/v1/auth/me")
assert is_auth_endpoint(req) is False
def test_csrf_skips_get_requests():
req = _FakeRequest("/api/v1/auth/me", method="GET")
assert should_check_csrf(req) is False
def test_csrf_checks_post_to_protected():
req = _FakeRequest("/api/v1/some/endpoint", method="POST")
assert should_check_csrf(req) is True
# ── Structured Error Response Format ────────────────────────────────
def test_auth_error_response_has_code_and_message():
"""All auth errors should have structured {code, message} format."""
err = AuthErrorResponse(
code=AuthErrorCode.INVALID_CREDENTIALS,
message="Wrong password",
)
d = err.model_dump()
assert "code" in d
assert "message" in d
assert d["code"] == "invalid_credentials"
def test_auth_error_response_all_codes_serializable():
"""Every AuthErrorCode should be serializable in AuthErrorResponse."""
for code in AuthErrorCode:
err = AuthErrorResponse(code=code, message=f"Test {code.value}")
d = err.model_dump()
assert d["code"] == code.value
# ── decode_token Caller Pattern ──────────────────────────────────────
def test_decode_token_expired_maps_to_token_expired_code():
"""TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED."""
_setup_config()
from datetime import UTC, datetime, timedelta
import jwt as pyjwt
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
result = decode_token(token)
assert result == TokenError.EXPIRED
# Verify the mapping pattern used in route handlers
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
assert code == AuthErrorCode.TOKEN_EXPIRED
def test_decode_token_invalid_sig_maps_to_token_invalid_code():
"""TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID."""
_setup_config()
from datetime import UTC, datetime, timedelta
import jwt as pyjwt
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
result = decode_token(token)
assert result == TokenError.INVALID_SIGNATURE
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
assert code == AuthErrorCode.TOKEN_INVALID
def test_decode_token_malformed_maps_to_token_invalid_code():
"""TokenError.MALFORMED should map to AuthErrorCode.TOKEN_INVALID."""
_setup_config()
result = decode_token("garbage")
assert result == TokenError.MALFORMED
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
assert code == AuthErrorCode.TOKEN_INVALID
# ── Login Response Format ────────────────────────────────────────────
def test_login_response_model_has_no_access_token():
"""LoginResponse should NOT contain access_token field (RFC-001)."""
from app.gateway.routers.auth import LoginResponse
resp = LoginResponse(expires_in=604800)
d = resp.model_dump()
assert "access_token" not in d
assert "expires_in" in d
assert d["expires_in"] == 604800
def test_login_response_model_fields():
"""LoginResponse has expires_in and needs_setup."""
from app.gateway.routers.auth import LoginResponse
fields = set(LoginResponse.model_fields.keys())
assert fields == {"expires_in", "needs_setup"}
# ── AuthConfig in Route ──────────────────────────────────────────────
def test_auth_config_token_expiry_used_in_login_response():
"""LoginResponse.expires_in should come from config.token_expiry_days."""
from app.gateway.routers.auth import LoginResponse
expected_seconds = 14 * 24 * 3600
resp = LoginResponse(expires_in=expected_seconds)
assert resp.expires_in == expected_seconds
# ── UserResponse Type Preservation ───────────────────────────────────
def test_user_response_system_role_literal():
"""UserResponse.system_role should only accept 'admin' or 'user'."""
from app.gateway.auth.models import UserResponse
# Valid roles
resp = UserResponse(id="1", email="a@b.com", system_role="admin")
assert resp.system_role == "admin"
resp = UserResponse(id="1", email="a@b.com", system_role="user")
assert resp.system_role == "user"
def test_user_response_rejects_invalid_role():
"""UserResponse should reject invalid system_role values."""
from app.gateway.auth.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com", system_role="superadmin")
# ══════════════════════════════════════════════════════════════════════
# UNHAPPY PATHS / EDGE CASES
# ══════════════════════════════════════════════════════════════════════
# ── get_current_user structured 401 responses ────────────────────────
def test_get_current_user_no_cookie_returns_not_authenticated():
"""No cookie → 401 with code=not_authenticated."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
mock_request = type("MockRequest", (), {"cookies": {}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "not_authenticated"
def test_get_current_user_expired_token_returns_token_expired():
"""Expired token → 401 with code=token_expired."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
_setup_config()
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "token_expired"
def test_get_current_user_invalid_token_returns_token_invalid():
"""Bad signature → 401 with code=token_invalid."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
_setup_config()
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "token_invalid"
def test_get_current_user_malformed_token_returns_token_invalid():
"""Garbage token → 401 with code=token_invalid."""
import asyncio
from fastapi import HTTPException
from app.gateway.deps import get_current_user_from_request
_setup_config()
mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
with pytest.raises(HTTPException) as exc_info:
asyncio.run(get_current_user_from_request(mock_request))
assert exc_info.value.status_code == 401
detail = exc_info.value.detail
assert detail["code"] == "token_invalid"
# ── decode_token edge cases ──────────────────────────────────────────
def test_decode_token_empty_string_returns_malformed():
_setup_config()
result = decode_token("")
assert result == TokenError.MALFORMED
def test_decode_token_whitespace_returns_malformed():
_setup_config()
result = decode_token(" ")
assert result == TokenError.MALFORMED
# ── AuthConfig validation edge cases ─────────────────────────────────
def test_auth_config_missing_jwt_secret_raises():
"""AuthConfig requires jwt_secret — no default allowed."""
with pytest.raises(ValidationError):
AuthConfig()
def test_auth_config_token_expiry_zero_raises():
"""token_expiry_days must be >= 1."""
with pytest.raises(ValidationError):
AuthConfig(jwt_secret="secret", token_expiry_days=0)
def test_auth_config_token_expiry_31_raises():
"""token_expiry_days must be <= 30."""
with pytest.raises(ValidationError):
AuthConfig(jwt_secret="secret", token_expiry_days=31)
def test_auth_config_token_expiry_boundary_1_ok():
config = AuthConfig(jwt_secret="secret", token_expiry_days=1)
assert config.token_expiry_days == 1
def test_auth_config_token_expiry_boundary_30_ok():
config = AuthConfig(jwt_secret="secret", token_expiry_days=30)
assert config.token_expiry_days == 30
def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
"""get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
import logging
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with caplog.at_level(logging.WARNING):
config = cfg.get_auth_config()
assert config.jwt_secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
finally:
cfg._auth_config = old
# ── CSRF middleware integration (unhappy paths) ──────────────────────
def _make_csrf_app():
"""Create a minimal FastAPI app with CSRFMiddleware for testing."""
from fastapi import HTTPException as _HTTPException
from fastapi.responses import JSONResponse as _JSONResponse
app = FastAPI()
@app.exception_handler(_HTTPException)
async def _http_exc_handler(request, exc):
return _JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
app.add_middleware(CSRFMiddleware)
@app.post("/api/v1/test/protected")
async def protected():
return {"ok": True}
@app.post("/api/v1/auth/login/local")
async def login():
return {"ok": True}
@app.get("/api/v1/test/read")
async def read_endpoint():
return {"ok": True}
return app
def test_csrf_middleware_blocks_post_without_token():
"""POST to protected endpoint without CSRF token → 403 with structured detail."""
client = TestClient(_make_csrf_app())
resp = client.post("/api/v1/test/protected")
assert resp.status_code == 403
assert "CSRF" in resp.json()["detail"]
assert "missing" in resp.json()["detail"].lower()
def test_csrf_middleware_blocks_post_with_mismatched_token():
"""POST with mismatched CSRF cookie/header → 403 with mismatch detail."""
client = TestClient(_make_csrf_app())
client.cookies.set(CSRF_COOKIE_NAME, "token-a")
resp = client.post(
"/api/v1/test/protected",
headers={CSRF_HEADER_NAME: "token-b"},
)
assert resp.status_code == 403
assert "mismatch" in resp.json()["detail"].lower()
def test_csrf_middleware_allows_post_with_matching_token():
"""POST with matching CSRF cookie/header → 200."""
client = TestClient(_make_csrf_app())
token = secrets.token_urlsafe(64)
client.cookies.set(CSRF_COOKIE_NAME, token)
resp = client.post(
"/api/v1/test/protected",
headers={CSRF_HEADER_NAME: token},
)
assert resp.status_code == 200
def test_csrf_middleware_allows_get_without_token():
"""GET requests bypass CSRF check."""
client = TestClient(_make_csrf_app())
resp = client.get("/api/v1/test/read")
assert resp.status_code == 200
def test_csrf_middleware_exempts_login_local():
"""POST to login/local is exempt from CSRF (no token yet)."""
client = TestClient(_make_csrf_app())
resp = client.post("/api/v1/auth/login/local")
assert resp.status_code == 200
def test_csrf_middleware_sets_cookie_on_auth_endpoint():
"""Auth endpoints should receive a CSRF cookie in response."""
client = TestClient(_make_csrf_app())
resp = client.post("/api/v1/auth/login/local")
assert CSRF_COOKIE_NAME in resp.cookies
# ── UserResponse edge cases ──────────────────────────────────────────
def test_user_response_missing_required_fields():
"""UserResponse with missing fields → ValidationError."""
from app.gateway.auth.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1") # missing email, system_role
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com") # missing system_role
def test_user_response_empty_string_role_rejected():
"""Empty string is not a valid role."""
from app.gateway.auth.models import UserResponse
with pytest.raises(ValidationError):
UserResponse(id="1", email="a@b.com", system_role="")
# ══════════════════════════════════════════════════════════════════════
# HTTP-LEVEL API CONTRACT TESTS
# ══════════════════════════════════════════════════════════════════════
def _make_auth_app():
"""Create FastAPI app with auth routes for contract testing."""
from app.gateway.app import create_app
return create_app()
def _get_auth_client():
"""Get TestClient for auth API contract tests."""
return TestClient(_make_auth_app())
def test_api_auth_me_no_cookie_returns_structured_401():
"""/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
_setup_config()
client = _get_auth_client()
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "not_authenticated"
assert "message" in body["detail"]
def test_api_auth_me_expired_token_returns_structured_401():
"""/api/v1/auth/me with expired token → 401 with {code: 'token_expired'}."""
_setup_config()
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
client = _get_auth_client()
client.cookies.set("access_token", token)
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "token_expired"
def test_api_auth_me_invalid_sig_returns_structured_401():
"""/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
_setup_config()
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
client = _get_auth_client()
client.cookies.set("access_token", token)
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "token_invalid"
def test_api_login_bad_credentials_returns_structured_401():
"""Login with wrong password → 401 with {code: 'invalid_credentials'}."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/login/local",
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
)
assert resp.status_code == 401
body = resp.json()
assert body["detail"]["code"] == "invalid_credentials"
def test_api_login_success_no_token_in_body():
"""Successful login → response body has expires_in but NOT access_token."""
_setup_config()
client = _get_auth_client()
# Register first
client.post(
"/api/v1/auth/register",
json={"email": "contract-test@test.com", "password": "securepassword123"},
)
# Login
resp = client.post(
"/api/v1/auth/login/local",
data={"username": "contract-test@test.com", "password": "securepassword123"},
)
assert resp.status_code == 200
body = resp.json()
assert "expires_in" in body
assert "access_token" not in body
# Token should be in cookie, not body
assert "access_token" in resp.cookies
def test_api_register_duplicate_returns_structured_400():
"""Register with duplicate email → 400 with {code: 'email_already_exists'}."""
_setup_config()
client = _get_auth_client()
email = "dup-contract-test@test.com"
# First register
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
# Duplicate
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "password456"})
assert resp.status_code == 400
body = resp.json()
assert body["detail"]["code"] == "email_already_exists"
# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────
def _unique_email(prefix: str) -> str:
return f"{prefix}-{secrets.token_hex(4)}@test.com"
def _get_set_cookie_headers(resp) -> list[str]:
"""Extract all set-cookie header values from a TestClient response."""
return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"]
def test_register_http_cookie_httponly_true_secure_false():
"""HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("http-cookie"), "password": "password123"},
)
assert resp.status_code == 201
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" not in cookie_header.lower().replace("samesite", "")
def test_register_https_cookie_httponly_true_secure_true():
"""HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("https-cookie"), "password": "password123"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 201
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" in cookie_header.lower()
assert "max-age" in cookie_header.lower()
def test_login_https_sets_secure_cookie():
"""HTTPS login → access_token cookie has secure flag."""
_setup_config()
client = _get_auth_client()
email = _unique_email("https-login")
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
resp = client.post(
"/api/v1/auth/login/local",
data={"username": email, "password": "password123"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 200
cookie_header = resp.headers.get("set-cookie", "")
assert "access_token=" in cookie_header
assert "httponly" in cookie_header.lower()
assert "secure" in cookie_header.lower()
def test_csrf_cookie_secure_on_https():
"""HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("csrf-https"), "password": "password123"},
headers={"x-forwarded-proto": "https"},
)
assert resp.status_code == 201
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
csrf_header = csrf_cookies[0]
assert "secure" in csrf_header.lower()
assert "httponly" not in csrf_header.lower()
def test_csrf_cookie_not_secure_on_http():
"""HTTP register → csrf_token cookie does NOT have secure flag."""
_setup_config()
client = _get_auth_client()
resp = client.post(
"/api/v1/auth/register",
json={"email": _unique_email("csrf-http"), "password": "password123"},
)
assert resp.status_code == 201
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
csrf_header = csrf_cookies[0]
assert "secure" not in csrf_header.lower().replace("samesite", "")
+240 -2
View File
@@ -7,12 +7,12 @@ import json
import tempfile
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.channels.base import Channel
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
from app.channels.store import ChannelStore
@@ -1718,6 +1718,159 @@ class TestFeishuChannel:
_run(go())
class TestWeComChannel:
def test_publish_ws_inbound_starts_stream_and_publishes_message(self, monkeypatch):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = WeComChannel(bus, config={})
channel._ws_client = SimpleNamespace(reply_stream=AsyncMock())
monkeypatch.setitem(
__import__("sys").modules,
"aibot",
SimpleNamespace(generate_req_id=lambda prefix: "stream-1"),
)
frame = {
"body": {
"msgid": "msg-1",
"from": {"userid": "user-1"},
"aibotid": "bot-1",
"chattype": "single",
}
}
files = [{"type": "image", "url": "https://example.com/image.png"}]
await channel._publish_ws_inbound(frame, "hello", files=files)
channel._ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "Working on it...", False)
bus.publish_inbound.assert_awaited_once()
inbound = bus.publish_inbound.await_args.args[0]
assert inbound.channel_name == "wecom"
assert inbound.chat_id == "user-1"
assert inbound.user_id == "user-1"
assert inbound.text == "hello"
assert inbound.thread_ts == "msg-1"
assert inbound.topic_id == "user-1"
assert inbound.files == files
assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single"}
assert channel._ws_frames["msg-1"] is frame
assert channel._ws_stream_ids["msg-1"] == "stream-1"
_run(go())
def test_publish_ws_inbound_uses_configured_working_message(self, monkeypatch):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = WeComChannel(bus, config={"working_message": "Please wait..."})
channel._ws_client = SimpleNamespace(reply_stream=AsyncMock())
channel._working_message = "Please wait..."
monkeypatch.setitem(
__import__("sys").modules,
"aibot",
SimpleNamespace(generate_req_id=lambda prefix: "stream-1"),
)
frame = {
"body": {
"msgid": "msg-1",
"from": {"userid": "user-1"},
}
}
await channel._publish_ws_inbound(frame, "hello")
channel._ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "Please wait...", False)
_run(go())
def test_on_outbound_sends_attachment_before_clearing_context(self, tmp_path):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
channel = WeComChannel(bus, config={})
frame = {"body": {"msgid": "msg-1"}}
ws_client = SimpleNamespace(
reply_stream=AsyncMock(),
reply=AsyncMock(),
)
channel._ws_client = ws_client
channel._ws_frames["msg-1"] = frame
channel._ws_stream_ids["msg-1"] = "stream-1"
channel._upload_media_ws = AsyncMock(return_value="media-1")
attachment_path = tmp_path / "image.png"
attachment_path.write_bytes(b"png")
attachment = ResolvedAttachment(
virtual_path="/mnt/user-data/outputs/image.png",
actual_path=attachment_path,
filename="image.png",
mime_type="image/png",
size=attachment_path.stat().st_size,
is_image=True,
)
msg = OutboundMessage(
channel_name="wecom",
chat_id="user-1",
thread_id="thread-1",
text="done",
attachments=[attachment],
is_final=True,
thread_ts="msg-1",
)
await channel._on_outbound(msg)
ws_client.reply_stream.assert_awaited_once_with(frame, "stream-1", "done", True)
channel._upload_media_ws.assert_awaited_once_with(
media_type="image",
filename="image.png",
path=str(attachment_path),
size=attachment.size,
)
ws_client.reply.assert_awaited_once_with(frame, {"image": {"media_id": "media-1"}, "msgtype": "image"})
assert "msg-1" not in channel._ws_frames
assert "msg-1" not in channel._ws_stream_ids
_run(go())
def test_send_falls_back_to_send_message_without_thread_context(self):
from app.channels.wecom import WeComChannel
async def go():
bus = MessageBus()
channel = WeComChannel(bus, config={})
channel._ws_client = SimpleNamespace(send_message=AsyncMock())
msg = OutboundMessage(
channel_name="wecom",
chat_id="user-1",
thread_id="thread-1",
text="hello",
thread_ts=None,
)
await channel.send(msg)
channel._ws_client.send_message.assert_awaited_once_with(
"user-1",
{"msgtype": "markdown", "markdown": {"content": "hello"}},
)
_run(go())
class TestChannelService:
def test_get_status_no_channels(self):
from app.channels.service import ChannelService
@@ -1835,6 +1988,47 @@ class TestSlackSendRetry:
_run(go())
class TestSlackAllowedUsers:
def test_numeric_allowed_users_match_string_event_user_id(self):
from app.channels.slack import SlackChannel
bus = MessageBus()
bus.publish_inbound = AsyncMock()
channel = SlackChannel(
bus=bus,
config={"allowed_users": [123456]},
)
channel._loop = MagicMock()
channel._loop.is_running.return_value = True
channel._add_reaction = MagicMock()
channel._send_running_reply = MagicMock()
event = {
"user": "123456",
"text": "hello from slack",
"channel": "C123",
"ts": "1710000000.000100",
}
def submit_coro(coro, loop):
coro.close()
return MagicMock()
with patch(
"app.channels.slack.asyncio.run_coroutine_threadsafe",
side_effect=submit_coro,
) as submit:
channel._handle_message_event(event)
channel._add_reaction.assert_called_once_with("C123", "1710000000.000100", "eyes")
channel._send_running_reply.assert_called_once_with("C123", "1710000000.000100")
submit.assert_called_once()
inbound = bus.publish_inbound.call_args.args[0]
assert inbound.user_id == "123456"
assert inbound.chat_id == "C123"
assert inbound.text == "hello from slack"
def test_raises_after_all_retries_exhausted(self):
from app.channels.slack import SlackChannel
@@ -1854,6 +2048,20 @@ class TestSlackSendRetry:
_run(go())
def test_raises_runtime_error_when_no_attempts_configured(self):
from app.channels.slack import SlackChannel
async def go():
bus = MessageBus()
ch = SlackChannel(bus=bus, config={"bot_token": "xoxb-test", "app_token": "xapp-test"})
ch._web_client = MagicMock()
msg = OutboundMessage(channel_name="slack", chat_id="C123", thread_id="t1", text="hello")
with pytest.raises(RuntimeError, match="without an exception"):
await ch.send(msg, _max_retries=0)
_run(go())
# ---------------------------------------------------------------------------
# Telegram send retry tests
@@ -1912,6 +2120,36 @@ class TestTelegramSendRetry:
_run(go())
def test_raises_runtime_error_when_no_attempts_configured(self):
from app.channels.telegram import TelegramChannel
async def go():
bus = MessageBus()
ch = TelegramChannel(bus=bus, config={"bot_token": "test-token"})
ch._application = MagicMock()
msg = OutboundMessage(channel_name="telegram", chat_id="12345", thread_id="t1", text="hello")
with pytest.raises(RuntimeError, match="without an exception"):
await ch.send(msg, _max_retries=0)
_run(go())
class TestFeishuSendRetry:
def test_raises_runtime_error_when_no_attempts_configured(self):
from app.channels.feishu import FeishuChannel
async def go():
bus = MessageBus()
ch = FeishuChannel(bus=bus, config={"app_id": "id", "app_secret": "secret"})
ch._api_client = MagicMock()
msg = OutboundMessage(channel_name="feishu", chat_id="chat", thread_id="t1", text="hello")
with pytest.raises(RuntimeError, match="without an exception"):
await ch.send(msg, _max_retries=0)
_run(go())
# ---------------------------------------------------------------------------
# Telegram private-chat thread context tests
+11 -2
View File
@@ -59,18 +59,20 @@ class TestClientInit:
assert client._subagent_enabled is False
assert client._plan_mode is False
assert client._agent_name is None
assert client._available_skills is None
assert client._checkpointer is None
assert client._agent is None
def test_custom_params(self, mock_app_config):
mock_middleware = MagicMock()
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", middlewares=[mock_middleware])
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent", available_skills={"skill1", "skill2"}, middlewares=[mock_middleware])
assert c._model_name == "gpt-4"
assert c._thinking_enabled is False
assert c._subagent_enabled is True
assert c._plan_mode is True
assert c._agent_name == "test-agent"
assert c._available_skills == {"skill1", "skill2"}
assert c._middlewares == [mock_middleware]
def test_invalid_agent_name(self, mock_app_config):
@@ -394,8 +396,10 @@ class TestEnsureAgent:
patch("deerflow.client._build_middlewares", return_value=[]) as mock_build_middlewares,
patch("deerflow.client.apply_prompt_template", return_value="prompt") as mock_apply_prompt,
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._agent_name = "custom-agent"
client._available_skills = {"test_skill"}
client._ensure_agent(config)
assert client._agent is mock_agent
@@ -404,6 +408,7 @@ class TestEnsureAgent:
assert mock_build_middlewares.call_args.kwargs.get("agent_name") == "custom-agent"
mock_apply_prompt.assert_called_once()
assert mock_apply_prompt.call_args.kwargs.get("agent_name") == "custom-agent"
assert mock_apply_prompt.call_args.kwargs.get("available_skills") == {"test_skill"}
def test_uses_default_checkpointer_when_available(self, client):
mock_agent = MagicMock()
@@ -441,6 +446,7 @@ class TestEnsureAgent:
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
@@ -469,7 +475,7 @@ class TestEnsureAgent:
"""_ensure_agent does not recreate if config key unchanged."""
mock_agent = MagicMock()
client._agent = mock_agent
client._agent_config_key = (None, True, False, False)
client._agent_config_key = (None, True, False, False, None, None)
config = client._get_runnable_config("t1")
client._ensure_agent(config)
@@ -1276,6 +1282,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config_a)
first_agent = client._agent
@@ -1303,6 +1310,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
client._ensure_agent(config)
@@ -1327,6 +1335,7 @@ class TestScenarioAgentRecreation:
patch("deerflow.client._build_middlewares", return_value=[]),
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
patch.object(client, "_get_tools", return_value=[]),
patch("deerflow.agents.checkpointer.get_checkpointer", return_value=MagicMock()),
):
client._ensure_agent(config)
client.reset_agent()
+9
View File
@@ -439,6 +439,15 @@ class TestAgentsAPI:
assert "agent-one" in names
assert "agent-two" in names
def test_list_agents_includes_soul(self, agent_client):
agent_client.post("/api/agents", json={"name": "soul-agent", "soul": "My soul content"})
response = agent_client.get("/api/agents")
assert response.status_code == 200
agents = response.json()["agents"]
soul_agent = next(a for a in agents if a["name"] == "soul-agent")
assert soul_agent["soul"] == "My soul content"
def test_get_agent(self, agent_client):
agent_client.post("/api/agents", json={"name": "test-agent", "soul": "Hello world"})
+214
View File
@@ -0,0 +1,214 @@
"""Tests for _ensure_admin_user() in app.py.
Covers: first-boot admin creation, auto-reset on needs_setup=True,
no-op on needs_setup=False, migration, and edge cases.
"""
import asyncio
import os
from datetime import UTC, datetime, timedelta
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
import pytest
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.models import User
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
@pytest.fixture(autouse=True)
def _setup_auth_config():
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
yield
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
def _make_app_stub(store=None):
"""Minimal app-like object with state.store."""
app = SimpleNamespace()
app.state = SimpleNamespace()
app.state.store = store
return app
def _make_provider(user_count=0, admin_user=None):
p = AsyncMock()
p.count_users = AsyncMock(return_value=user_count)
p.create_user = AsyncMock(
side_effect=lambda **kw: User(
email=kw["email"],
password_hash="hashed",
system_role=kw.get("system_role", "user"),
needs_setup=kw.get("needs_setup", False),
)
)
p.get_user_by_email = AsyncMock(return_value=admin_user)
p.update_user = AsyncMock(side_effect=lambda u: u)
return p
# ── First boot: no users ─────────────────────────────────────────────────
def test_first_boot_creates_admin():
"""count_users==0 → create admin with needs_setup=True."""
provider = _make_provider(user_count=0)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.create_user.assert_called_once()
call_kwargs = provider.create_user.call_args[1]
assert call_kwargs["email"] == "admin@deerflow.dev"
assert call_kwargs["system_role"] == "admin"
assert call_kwargs["needs_setup"] is True
assert len(call_kwargs["password"]) > 10 # random password generated
def test_first_boot_triggers_migration_if_store_present():
"""First boot with store → _migrate_orphaned_threads called."""
provider = _make_provider(user_count=0)
store = AsyncMock()
store.asearch = AsyncMock(return_value=[])
app = _make_app_stub(store=store)
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
store.asearch.assert_called_once()
def test_first_boot_no_store_skips_migration():
"""First boot without store → no crash, migration skipped."""
provider = _make_provider(user_count=0)
app = _make_app_stub(store=None)
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.create_user.assert_called_once()
# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────
def test_needs_setup_true_resets_password():
"""Existing admin with needs_setup=True → password reset + token_version bumped."""
admin = User(
email="admin@deerflow.dev",
password_hash="old-hash",
system_role="admin",
needs_setup=True,
token_version=0,
created_at=datetime.now(UTC) - timedelta(seconds=30),
)
provider = _make_provider(user_count=1, admin_user=admin)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
# Password was reset
provider.update_user.assert_called_once()
updated = provider.update_user.call_args[0][0]
assert updated.password_hash == "new-hash"
assert updated.token_version == 1
def test_needs_setup_true_consecutive_resets_increment_version():
"""Two boots with needs_setup=True → token_version increments each time."""
admin = User(
email="admin@deerflow.dev",
password_hash="hash",
system_role="admin",
needs_setup=True,
token_version=3,
created_at=datetime.now(UTC) - timedelta(seconds=30),
)
provider = _make_provider(user_count=1, admin_user=admin)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
updated = provider.update_user.call_args[0][0]
assert updated.token_version == 4
# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────
def test_needs_setup_false_no_reset():
"""Admin with needs_setup=False → no password reset, no update."""
admin = User(
email="admin@deerflow.dev",
password_hash="stable-hash",
system_role="admin",
needs_setup=False,
token_version=2,
)
provider = _make_provider(user_count=1, admin_user=admin)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.update_user.assert_not_called()
assert admin.password_hash == "stable-hash"
assert admin.token_version == 2
# ── Edge cases ───────────────────────────────────────────────────────────
def test_no_admin_email_found_no_crash():
"""Users exist but no admin@deerflow.dev → no crash, no reset."""
provider = _make_provider(user_count=3, admin_user=None)
app = _make_app_stub()
with patch("app.gateway.deps.get_local_provider", return_value=provider):
from app.gateway.app import _ensure_admin_user
asyncio.run(_ensure_admin_user(app))
provider.update_user.assert_not_called()
provider.create_user.assert_not_called()
def test_migration_failure_is_non_fatal():
"""_migrate_orphaned_threads exception is caught and logged."""
provider = _make_provider(user_count=0)
store = AsyncMock()
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
app = _make_app_stub(store=store)
with patch("app.gateway.deps.get_local_provider", return_value=provider):
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
from app.gateway.app import _ensure_admin_user
# Should not raise
asyncio.run(_ensure_admin_user(app))
provider.create_user.assert_called_once()
+459
View File
@@ -0,0 +1,459 @@
"""Tests for file_conversion utilities (PR1: pymupdf4llm + asyncio.to_thread; PR2: extract_outline)."""
from __future__ import annotations
import asyncio
import sys
from types import ModuleType
from unittest.mock import MagicMock, patch
from deerflow.utils.file_conversion import (
_ASYNC_THRESHOLD_BYTES,
_MIN_CHARS_PER_PAGE,
MAX_OUTLINE_ENTRIES,
_do_convert,
_pymupdf_output_too_sparse,
convert_file_to_markdown,
extract_outline,
)
def _make_pymupdf_mock(page_count: int) -> ModuleType:
"""Return a fake *pymupdf* module whose ``open()`` reports *page_count* pages."""
mock_doc = MagicMock()
mock_doc.__len__ = MagicMock(return_value=page_count)
fake_pymupdf = ModuleType("pymupdf")
fake_pymupdf.open = MagicMock(return_value=mock_doc) # type: ignore[attr-defined]
return fake_pymupdf
def _run(coro):
loop = asyncio.new_event_loop()
try:
return loop.run_until_complete(coro)
finally:
loop.close()
# ---------------------------------------------------------------------------
# _pymupdf_output_too_sparse
# ---------------------------------------------------------------------------
class TestPymupdfOutputTooSparse:
"""Check the chars-per-page sparsity heuristic."""
def test_dense_text_pdf_not_sparse(self, tmp_path):
"""Normal text PDF: many chars per page → not sparse."""
pdf = tmp_path / "dense.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# 10 pages × 10 000 chars → 1000/page ≫ threshold
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=10)}):
result = _pymupdf_output_too_sparse("x" * 10_000, pdf)
assert result is False
def test_image_based_pdf_is_sparse(self, tmp_path):
"""Image-based PDF: near-zero chars per page → sparse."""
pdf = tmp_path / "image.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# 612 chars / 31 pages ≈ 19.7/page < _MIN_CHARS_PER_PAGE (50)
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=31)}):
result = _pymupdf_output_too_sparse("x" * 612, pdf)
assert result is True
def test_fallback_when_pymupdf_unavailable(self, tmp_path):
"""When pymupdf is not installed, fall back to absolute 200-char threshold."""
pdf = tmp_path / "broken.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# Remove pymupdf from sys.modules so the `import pymupdf` inside the
# function raises ImportError, triggering the absolute-threshold fallback.
with patch.dict(sys.modules, {"pymupdf": None}):
sparse = _pymupdf_output_too_sparse("x" * 100, pdf)
not_sparse = _pymupdf_output_too_sparse("x" * 300, pdf)
assert sparse is True
assert not_sparse is False
def test_exactly_at_threshold_is_not_sparse(self, tmp_path):
"""Chars-per-page == threshold is treated as NOT sparse (boundary inclusive)."""
pdf = tmp_path / "boundary.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
# 2 pages × _MIN_CHARS_PER_PAGE chars = exactly at threshold
with patch.dict(sys.modules, {"pymupdf": _make_pymupdf_mock(page_count=2)}):
result = _pymupdf_output_too_sparse("x" * (_MIN_CHARS_PER_PAGE * 2), pdf)
assert result is False
# ---------------------------------------------------------------------------
# _do_convert — routing logic
# ---------------------------------------------------------------------------
class TestDoConvert:
"""Verify that _do_convert routes to the right sub-converter."""
def test_non_pdf_always_uses_markitdown(self, tmp_path):
"""DOCX / XLSX / PPTX always go through MarkItDown regardless of setting."""
docx = tmp_path / "report.docx"
docx.write_bytes(b"PK fake docx")
with patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="# Markdown from MarkItDown",
) as mock_md:
result = _do_convert(docx, "auto")
mock_md.assert_called_once_with(docx)
assert result == "# Markdown from MarkItDown"
def test_pdf_auto_uses_pymupdf4llm_when_dense(self, tmp_path):
"""auto mode: use pymupdf4llm output when it's dense enough."""
pdf = tmp_path / "report.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
dense_text = "# Heading\n" + "word " * 2000 # clearly dense
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value=dense_text,
),
patch(
"deerflow.utils.file_conversion._pymupdf_output_too_sparse",
return_value=False,
),
patch("deerflow.utils.file_conversion._convert_with_markitdown") as mock_md,
):
result = _do_convert(pdf, "auto")
mock_md.assert_not_called()
assert result == dense_text
def test_pdf_auto_falls_back_when_sparse(self, tmp_path):
"""auto mode: fall back to MarkItDown when pymupdf4llm output is sparse."""
pdf = tmp_path / "scanned.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value="x" * 612, # 19.7 chars/page for 31-page doc
),
patch(
"deerflow.utils.file_conversion._pymupdf_output_too_sparse",
return_value=True,
),
patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="OCR result via MarkItDown",
) as mock_md,
):
result = _do_convert(pdf, "auto")
mock_md.assert_called_once_with(pdf)
assert result == "OCR result via MarkItDown"
def test_pdf_explicit_pymupdf4llm_skips_sparsity_check(self, tmp_path):
"""'pymupdf4llm' mode: use output as-is even if sparse."""
pdf = tmp_path / "explicit.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
sparse_text = "x" * 10 # very short
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value=sparse_text,
),
patch("deerflow.utils.file_conversion._convert_with_markitdown") as mock_md,
):
result = _do_convert(pdf, "pymupdf4llm")
mock_md.assert_not_called()
assert result == sparse_text
def test_pdf_explicit_markitdown_skips_pymupdf4llm(self, tmp_path):
"""'markitdown' mode: never attempt pymupdf4llm."""
pdf = tmp_path / "force_md.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch("deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm") as mock_pymu,
patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="MarkItDown result",
),
):
result = _do_convert(pdf, "markitdown")
mock_pymu.assert_not_called()
assert result == "MarkItDown result"
def test_pdf_auto_falls_back_when_pymupdf4llm_not_installed(self, tmp_path):
"""auto mode: if pymupdf4llm is not installed, use MarkItDown directly."""
pdf = tmp_path / "no_pymupdf.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch(
"deerflow.utils.file_conversion._convert_pdf_with_pymupdf4llm",
return_value=None, # None signals not installed
),
patch(
"deerflow.utils.file_conversion._convert_with_markitdown",
return_value="MarkItDown fallback",
) as mock_md,
):
result = _do_convert(pdf, "auto")
mock_md.assert_called_once_with(pdf)
assert result == "MarkItDown fallback"
# ---------------------------------------------------------------------------
# convert_file_to_markdown — async + file writing
# ---------------------------------------------------------------------------
class TestConvertFileToMarkdown:
def test_small_file_runs_synchronously(self, tmp_path):
"""Small files (< 1 MB) are converted in the event loop thread."""
pdf = tmp_path / "small.pdf"
pdf.write_bytes(b"%PDF-1.4 " + b"x" * 100) # well under 1 MB
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
return_value="# Small PDF",
) as mock_convert,
patch("asyncio.to_thread") as mock_thread,
):
md_path = _run(convert_file_to_markdown(pdf))
# asyncio.to_thread must NOT have been called
mock_thread.assert_not_called()
mock_convert.assert_called_once()
assert md_path == pdf.with_suffix(".md")
assert md_path.read_text() == "# Small PDF"
def test_large_file_offloaded_to_thread(self, tmp_path):
"""Large files (> 1 MB) are offloaded via asyncio.to_thread."""
pdf = tmp_path / "large.pdf"
# Write slightly more than the threshold
pdf.write_bytes(b"%PDF-1.4 " + b"x" * (_ASYNC_THRESHOLD_BYTES + 1))
async def fake_to_thread(fn, *args, **kwargs):
return fn(*args, **kwargs)
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
return_value="# Large PDF",
),
patch("asyncio.to_thread", side_effect=fake_to_thread) as mock_thread,
):
md_path = _run(convert_file_to_markdown(pdf))
mock_thread.assert_called_once()
assert md_path == pdf.with_suffix(".md")
assert md_path.read_text() == "# Large PDF"
def test_returns_none_on_conversion_error(self, tmp_path):
"""If conversion raises, return None without propagating the exception."""
pdf = tmp_path / "broken.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
side_effect=RuntimeError("conversion failed"),
),
):
result = _run(convert_file_to_markdown(pdf))
assert result is None
def test_writes_utf8_markdown_file(self, tmp_path):
"""Generated .md file is written with UTF-8 encoding."""
pdf = tmp_path / "report.pdf"
pdf.write_bytes(b"%PDF-1.4 fake")
chinese_content = "# 中文报告\n\n这是测试内容。"
with (
patch("deerflow.utils.file_conversion._get_pdf_converter", return_value="auto"),
patch(
"deerflow.utils.file_conversion._do_convert",
return_value=chinese_content,
),
):
md_path = _run(convert_file_to_markdown(pdf))
assert md_path is not None
assert md_path.read_text(encoding="utf-8") == chinese_content
# ---------------------------------------------------------------------------
# extract_outline
# ---------------------------------------------------------------------------
class TestExtractOutline:
"""Tests for extract_outline()."""
def test_empty_file_returns_empty(self, tmp_path):
"""Empty markdown file yields no outline entries."""
md = tmp_path / "empty.md"
md.write_text("", encoding="utf-8")
assert extract_outline(md) == []
def test_missing_file_returns_empty(self, tmp_path):
"""Non-existent path returns [] without raising."""
assert extract_outline(tmp_path / "nonexistent.md") == []
def test_standard_markdown_headings(self, tmp_path):
"""# / ## / ### headings are all recognised."""
md = tmp_path / "doc.md"
md.write_text(
"# Chapter One\n\nSome text.\n\n## Section 1.1\n\nMore text.\n\n### Sub 1.1.1\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 3
assert outline[0] == {"title": "Chapter One", "line": 1}
assert outline[1] == {"title": "Section 1.1", "line": 5}
assert outline[2] == {"title": "Sub 1.1.1", "line": 9}
def test_bold_sec_item_heading(self, tmp_path):
"""**ITEM N. TITLE** lines in SEC filings are recognised."""
md = tmp_path / "10k.md"
md.write_text(
"Cover page text.\n\n**ITEM 1. BUSINESS**\n\nBody.\n\n**ITEM 1A. RISK FACTORS**\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 2
assert outline[0] == {"title": "ITEM 1. BUSINESS", "line": 3}
assert outline[1] == {"title": "ITEM 1A. RISK FACTORS", "line": 7}
def test_bold_part_heading(self, tmp_path):
"""**PART I** / **PART II** headings are recognised."""
md = tmp_path / "10k.md"
md.write_text("**PART I**\n\n**PART II**\n\n**PART III**\n", encoding="utf-8")
outline = extract_outline(md)
assert len(outline) == 3
titles = [e["title"] for e in outline]
assert "PART I" in titles
assert "PART II" in titles
assert "PART III" in titles
def test_sec_cover_page_boilerplate_excluded(self, tmp_path):
"""Address lines and short cover boilerplate must NOT appear in outline."""
md = tmp_path / "8k.md"
md.write_text(
"## **UNITED STATES SECURITIES AND EXCHANGE COMMISSION**\n\n**WASHINGTON, DC 20549**\n\n**CURRENT REPORT**\n\n**SIGNATURES**\n\n**TESLA, INC.**\n\n**ITEM 2.02. RESULTS OF OPERATIONS**\n",
encoding="utf-8",
)
outline = extract_outline(md)
titles = [e["title"] for e in outline]
# Cover-page boilerplate should be excluded
assert "WASHINGTON, DC 20549" not in titles
assert "CURRENT REPORT" not in titles
assert "SIGNATURES" not in titles
assert "TESLA, INC." not in titles
# Real SEC heading must be included
assert "ITEM 2.02. RESULTS OF OPERATIONS" in titles
def test_chinese_headings_via_standard_markdown(self, tmp_path):
"""Chinese annual report headings emitted as # by pymupdf4llm are captured."""
md = tmp_path / "annual.md"
md.write_text(
"# 第一节 公司简介\n\n内容。\n\n## 第三节 管理层讨论与分析\n\n分析内容。\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 2
assert outline[0]["title"] == "第一节 公司简介"
assert outline[1]["title"] == "第三节 管理层讨论与分析"
def test_outline_capped_at_max_entries(self, tmp_path):
"""When truncated, result has MAX_OUTLINE_ENTRIES real entries + 1 sentinel."""
lines = [f"# Heading {i}" for i in range(MAX_OUTLINE_ENTRIES + 10)]
md = tmp_path / "long.md"
md.write_text("\n".join(lines), encoding="utf-8")
outline = extract_outline(md)
# Last entry is the truncation sentinel
assert outline[-1] == {"truncated": True}
# Visible entries are exactly MAX_OUTLINE_ENTRIES
visible = [e for e in outline if not e.get("truncated")]
assert len(visible) == MAX_OUTLINE_ENTRIES
def test_no_truncation_sentinel_when_under_limit(self, tmp_path):
"""Short documents produce no sentinel entry."""
lines = [f"# Heading {i}" for i in range(5)]
md = tmp_path / "short.md"
md.write_text("\n".join(lines), encoding="utf-8")
outline = extract_outline(md)
assert len(outline) == 5
assert not any(e.get("truncated") for e in outline)
def test_blank_lines_and_whitespace_ignored(self, tmp_path):
"""Blank lines between headings do not produce empty entries."""
md = tmp_path / "spaced.md"
md.write_text("\n\n# Title One\n\n\n\n# Title Two\n\n", encoding="utf-8")
outline = extract_outline(md)
assert len(outline) == 2
assert all(e["title"] for e in outline)
def test_inline_bold_not_confused_with_heading(self, tmp_path):
"""Mid-sentence bold text must not be mistaken for a heading."""
md = tmp_path / "prose.md"
md.write_text(
"This sentence has **bold words** inside it.\n\nAnother with **MULTIPLE CAPS** inline.\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert outline == []
def test_split_bold_heading_academic_paper(self, tmp_path):
"""**<num>** **<title>** lines from academic papers are recognised (Style 3)."""
md = tmp_path / "paper.md"
md.write_text(
"## **Attention Is All You Need**\n\n**1** **Introduction**\n\nBody text.\n\n**2** **Background**\n\nMore text.\n\n**3.1** **Encoder and Decoder Stacks**\n",
encoding="utf-8",
)
outline = extract_outline(md)
titles = [e["title"] for e in outline]
assert "1 Introduction" in titles
assert "2 Background" in titles
assert "3.1 Encoder and Decoder Stacks" in titles
def test_split_bold_year_columns_excluded(self, tmp_path):
"""Financial table headers like **2023** **2022** **2021** are NOT headings."""
md = tmp_path / "annual.md"
md.write_text(
"# Financial Summary\n\n**2023** **2022** **2021**\n\nRevenue 100 90 80\n",
encoding="utf-8",
)
outline = extract_outline(md)
titles = [e["title"] for e in outline]
# Only the # heading should appear, not the year-column row
assert titles == ["Financial Summary"]
def test_adjacent_bold_spans_merged_in_markdown_heading(self, tmp_path):
"""** ** artefacts inside a # heading are merged into clean plain text."""
md = tmp_path / "sec.md"
md.write_text(
"## **UNITED STATES** **SECURITIES AND EXCHANGE COMMISSION**\n\nBody text.\n",
encoding="utf-8",
)
outline = extract_outline(md)
assert len(outline) == 1
# Title must be clean — no ** ** artefacts
assert outline[0]["title"] == "UNITED STATES SECURITIES AND EXCHANGE COMMISSION"
+135 -3
View File
@@ -8,6 +8,7 @@ import pytest
from deerflow.config.acp_config import ACPAgentConfig
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig, set_extensions_config
from deerflow.tools.builtins.invoke_acp_agent_tool import (
_build_acp_mcp_servers,
_build_mcp_servers,
_build_permission_response,
_get_work_dir,
@@ -42,6 +43,43 @@ def test_build_mcp_servers_filters_disabled_and_maps_transports():
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_acp_mcp_servers_formats_list_payload():
set_extensions_config(ExtensionsConfig(mcp_servers={"stale": McpServerConfig(enabled=True, type="stdio", command="echo")}, skills={}))
fresh_config = ExtensionsConfig(
mcp_servers={
"stdio": McpServerConfig(enabled=True, type="stdio", command="npx", args=["srv"], env={"FOO": "bar"}),
"http": McpServerConfig(enabled=True, type="http", url="https://example.com/mcp", headers={"Authorization": "Bearer token"}),
"disabled": McpServerConfig(enabled=False, type="stdio", command="echo"),
},
skills={},
)
monkeypatch = pytest.MonkeyPatch()
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: fresh_config),
)
try:
assert _build_acp_mcp_servers() == [
{
"name": "stdio",
"type": "stdio",
"command": "npx",
"args": ["srv"],
"env": [{"name": "FOO", "value": "bar"}],
},
{
"name": "http",
"type": "http",
"url": "https://example.com/mcp",
"headers": [{"name": "Authorization", "value": "Bearer token"}],
},
]
finally:
monkeypatch.undo()
set_extensions_config(ExtensionsConfig(mcp_servers={}, skills={}))
def test_build_permission_response_prefers_allow_once():
response = _build_permission_response(
[
@@ -251,9 +289,15 @@ async def test_invoke_acp_agent_uses_fixed_acp_workspace(monkeypatch, tmp_path):
assert captured["spawn"] == {"cmd": "codex-acp", "args": ["--json"], "cwd": expected_cwd}
assert captured["new_session"] == {
"cwd": expected_cwd,
"mcp_servers": {
"github": {"transport": "stdio", "command": "npx", "args": ["github-mcp"]},
},
"mcp_servers": [
{
"name": "github",
"type": "stdio",
"command": "npx",
"args": ["github-mcp"],
"env": [],
}
],
"model": "gpt-5-codex",
}
assert captured["prompt"] == {
@@ -448,6 +492,94 @@ async def test_invoke_acp_agent_passes_env_to_spawn(monkeypatch, tmp_path):
assert captured["env"] == {"OPENAI_API_KEY": "sk-from-env", "FOO": "bar"}
@pytest.mark.anyio
async def test_invoke_acp_agent_skips_invalid_mcp_servers(monkeypatch, tmp_path, caplog):
"""Invalid MCP config should be logged and skipped instead of failing ACP invocation."""
from deerflow.config import paths as paths_module
monkeypatch.setattr(paths_module, "get_paths", lambda: paths_module.Paths(base_dir=tmp_path))
monkeypatch.setattr(
"deerflow.tools.builtins.invoke_acp_agent_tool._build_acp_mcp_servers",
lambda: (_ for _ in ()).throw(ValueError("missing command")),
)
captured: dict[str, object] = {}
class DummyClient:
def __init__(self) -> None:
self._chunks: list[str] = []
@property
def collected_text(self) -> str:
return ""
async def session_update(self, session_id, update, **kwargs):
pass
async def request_permission(self, options, session_id, tool_call, **kwargs):
raise AssertionError("should not be called")
class DummyConn:
async def initialize(self, **kwargs):
pass
async def new_session(self, **kwargs):
captured["new_session"] = kwargs
return SimpleNamespace(session_id="s1")
async def prompt(self, **kwargs):
pass
class DummyProcessContext:
def __init__(self, client, cmd, *args, env=None, cwd=None):
captured["spawn"] = {"cmd": cmd, "args": list(args), "env": env, "cwd": cwd}
async def __aenter__(self):
return DummyConn(), object()
async def __aexit__(self, exc_type, exc, tb):
return False
class DummyRequestError(Exception):
@staticmethod
def method_not_found(method):
return DummyRequestError(method)
monkeypatch.setitem(
sys.modules,
"acp",
SimpleNamespace(
PROTOCOL_VERSION="2026-03-24",
Client=DummyClient,
RequestError=DummyRequestError,
spawn_agent_process=lambda client, cmd, *args, env=None, cwd: DummyProcessContext(client, cmd, *args, env=env, cwd=cwd),
text_block=lambda text: {"type": "text", "text": text},
),
)
monkeypatch.setitem(
sys.modules,
"acp.schema",
SimpleNamespace(
ClientCapabilities=lambda: {},
Implementation=lambda **kwargs: kwargs,
TextContentBlock=type("TextContentBlock", (), {"__init__": lambda self, text: setattr(self, "text", text)}),
),
)
tool = build_invoke_acp_agent_tool({"codex": ACPAgentConfig(command="codex-acp", description="Codex CLI")})
caplog.set_level("WARNING")
try:
await tool.coroutine(agent="codex", prompt="Do something")
finally:
sys.modules.pop("acp", None)
sys.modules.pop("acp.schema", None)
assert captured["new_session"]["mcp_servers"] == []
assert "continuing without MCP servers" in caplog.text
assert "missing command" in caplog.text
@pytest.mark.anyio
async def test_invoke_acp_agent_passes_none_env_when_not_configured(monkeypatch, tmp_path):
"""When env is empty, None is passed to spawn_agent_process (subprocess inherits parent env)."""
+312
View File
@@ -0,0 +1,312 @@
"""Tests for LangGraph Server auth handler (langgraph_auth.py).
Validates that the LangGraph auth layer enforces the same rules as Gateway:
cookie → JWT decode → DB lookup → token_version check → owner filter
"""
import asyncio
import os
from datetime import timedelta
from pathlib import Path
from types import SimpleNamespace
from unittest.mock import AsyncMock, patch
from uuid import uuid4
import pytest
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")
from langgraph_sdk import Auth
from app.gateway.auth.config import AuthConfig, set_auth_config
from app.gateway.auth.jwt import create_access_token, decode_token
from app.gateway.auth.models import User
from app.gateway.langgraph_auth import add_owner_filter, authenticate
# ── Helpers ───────────────────────────────────────────────────────────────
_JWT_SECRET = "test-secret-key-for-langgraph-auth-testing-min-32"
@pytest.fixture(autouse=True)
def _setup_auth_config():
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
yield
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
def _req(cookies=None, method="GET", headers=None):
return SimpleNamespace(cookies=cookies or {}, method=method, headers=headers or {})
def _user(user_id=None, token_version=0):
return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)
def _mock_provider(user=None):
p = AsyncMock()
p.get_user = AsyncMock(return_value=user)
return p
# ── @auth.authenticate ───────────────────────────────────────────────────
def test_no_cookie_raises_401():
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req()))
assert exc.value.status_code == 401
assert "Not authenticated" in str(exc.value.detail)
def test_invalid_jwt_raises_401():
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": "garbage"})))
assert exc.value.status_code == 401
assert "Token error" in str(exc.value.detail)
def test_expired_jwt_raises_401():
token = create_access_token("user-1", expires_delta=timedelta(seconds=-1))
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
assert exc.value.status_code == 401
def test_user_not_found_raises_401():
token = create_access_token("ghost")
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)):
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
assert exc.value.status_code == 401
assert "User not found" in str(exc.value.detail)
def test_token_version_mismatch_raises_401():
user = _user(token_version=2)
token = create_access_token(str(user.id), token_version=1)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": token})))
assert exc.value.status_code == 401
assert "revoked" in str(exc.value.detail).lower()
def test_valid_token_returns_user_id():
user = _user(token_version=0)
token = create_access_token(str(user.id), token_version=0)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": token})))
assert result == str(user.id)
def test_valid_token_matching_version():
user = _user(token_version=5)
token = create_access_token(str(user.id), token_version=5)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": token})))
assert result == str(user.id)
# ── @auth.authenticate edge cases ────────────────────────────────────────
def test_provider_exception_propagates():
"""Provider raises → should not be swallowed silently."""
token = create_access_token("user-1")
p = AsyncMock()
p.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p):
with pytest.raises(RuntimeError, match="DB down"):
asyncio.run(authenticate(_req({"access_token": token})))
def test_jwt_missing_ver_defaults_to_zero():
"""JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0."""
import jwt as pyjwt
uid = str(uuid4())
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
user = _user(user_id=uid, token_version=0)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
result = asyncio.run(authenticate(_req({"access_token": raw})))
assert result == uid
def test_jwt_missing_ver_rejected_when_user_version_nonzero():
"""JWT without 'ver' (defaults 0) vs user with token_version=1 → 401."""
import jwt as pyjwt
uid = str(uuid4())
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
user = _user(user_id=uid, token_version=1)
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": raw})))
assert exc.value.status_code == 401
def test_wrong_secret_raises_401():
"""Token signed with different secret → 401."""
import jwt as pyjwt
raw = pyjwt.encode({"sub": "user-1", "exp": 9999999999, "ver": 0}, "wrong-secret-that-is-long-enough-32chars!", algorithm="HS256")
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req({"access_token": raw})))
assert exc.value.status_code == 401
# ── @auth.on (owner filter) ──────────────────────────────────────────────
class _FakeUser:
"""Minimal BaseUser-compatible object without langgraph_api.config dependency."""
def __init__(self, identity: str):
self.identity = identity
self.is_authenticated = True
self.display_name = identity
def _make_ctx(user_id):
return Auth.types.AuthContext(resource="threads", action="create", user=_FakeUser(user_id), permissions=[])
def test_filter_injects_user_id():
value = {}
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
assert value["metadata"]["user_id"] == "user-a"
def test_filter_preserves_existing_metadata():
value = {"metadata": {"title": "hello"}}
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
assert value["metadata"]["user_id"] == "user-a"
assert value["metadata"]["title"] == "hello"
def test_filter_returns_user_id_dict():
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
assert result == {"user_id": "user-x"}
def test_filter_read_write_consistency():
value = {}
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
assert value["metadata"]["user_id"] == filter_dict["user_id"]
def test_different_users_different_filters():
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
assert f_a["user_id"] != f_b["user_id"]
def test_filter_overrides_conflicting_user_id():
"""If value already has a different user_id in metadata, it gets overwritten."""
value = {"metadata": {"user_id": "attacker"}}
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
assert value["metadata"]["user_id"] == "real-owner"
def test_filter_with_empty_metadata():
"""Explicit empty metadata dict is fine."""
value = {"metadata": {}}
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
assert value["metadata"]["user_id"] == "user-z"
assert result == {"user_id": "user-z"}
# ── Gateway parity ───────────────────────────────────────────────────────
def test_shared_jwt_secret():
token = create_access_token("user-1", token_version=3)
payload = decode_token(token)
from app.gateway.auth.errors import TokenError
assert not isinstance(payload, TokenError)
assert payload.sub == "user-1"
assert payload.ver == 3
def test_langgraph_json_has_auth_path():
import json
config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text())
assert "auth" in config
assert "langgraph_auth" in config["auth"]["path"]
def test_auth_handler_has_both_layers():
from app.gateway.langgraph_auth import auth
assert auth._authenticate_handler is not None
assert len(auth._global_handlers) == 1
# ── CSRF in LangGraph auth ──────────────────────────────────────────────
def test_csrf_get_no_check():
"""GET requests skip CSRF — should proceed to JWT validation."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="GET")))
# Rejected by missing cookie, NOT by CSRF
assert exc.value.status_code == 401
assert "Not authenticated" in str(exc.value.detail)
def test_csrf_post_missing_token():
"""POST without CSRF token → 403."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="POST", cookies={"access_token": "some-jwt"})))
assert exc.value.status_code == 403
assert "CSRF token missing" in str(exc.value.detail)
def test_csrf_post_mismatched_token():
"""POST with mismatched CSRF tokens → 403."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(
authenticate(
_req(
method="POST",
cookies={"access_token": "some-jwt", "csrf_token": "real-token"},
headers={"x-csrf-token": "wrong-token"},
)
)
)
assert exc.value.status_code == 403
assert "mismatch" in str(exc.value.detail)
def test_csrf_post_matching_token_proceeds_to_jwt():
"""POST with matching CSRF tokens passes CSRF check, then fails on JWT."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(
authenticate(
_req(
method="POST",
cookies={"access_token": "garbage", "csrf_token": "same-token"},
headers={"x-csrf-token": "same-token"},
)
)
)
# Past CSRF, rejected by JWT decode
assert exc.value.status_code == 401
assert "Token error" in str(exc.value.detail)
def test_csrf_put_requires_token():
"""PUT also requires CSRF."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="PUT", cookies={"access_token": "jwt"})))
assert exc.value.status_code == 403
def test_csrf_delete_requires_token():
"""DELETE also requires CSRF."""
with pytest.raises(Auth.exceptions.HTTPException) as exc:
asyncio.run(authenticate(_req(method="DELETE", cookies={"access_token": "jwt"})))
assert exc.value.status_code == 403
@@ -0,0 +1,388 @@
import errno
from types import SimpleNamespace
from unittest.mock import patch
import pytest
from deerflow.sandbox.local.local_sandbox import LocalSandbox, PathMapping
from deerflow.sandbox.local.local_sandbox_provider import LocalSandboxProvider
class TestPathMapping:
def test_path_mapping_dataclass(self):
mapping = PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True)
assert mapping.container_path == "/mnt/skills"
assert mapping.local_path == "/home/user/skills"
assert mapping.read_only is True
def test_path_mapping_defaults_to_false(self):
mapping = PathMapping(container_path="/mnt/data", local_path="/home/user/data")
assert mapping.read_only is False
class TestLocalSandboxPathResolution:
def test_resolve_path_exact_match(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
],
)
resolved = sandbox._resolve_path("/mnt/skills")
assert resolved == "/home/user/skills"
def test_resolve_path_nested_path(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
],
)
resolved = sandbox._resolve_path("/mnt/skills/agent/prompt.py")
assert resolved == "/home/user/skills/agent/prompt.py"
def test_resolve_path_no_mapping(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
],
)
resolved = sandbox._resolve_path("/mnt/other/file.txt")
assert resolved == "/mnt/other/file.txt"
def test_resolve_path_longest_prefix_first(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills"),
PathMapping(container_path="/mnt", local_path="/var/mnt"),
],
)
resolved = sandbox._resolve_path("/mnt/skills/file.py")
# Should match /mnt/skills first (longer prefix)
assert resolved == "/home/user/skills/file.py"
def test_reverse_resolve_path_exact_match(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)),
],
)
resolved = sandbox._reverse_resolve_path(str(skills_dir))
assert resolved == "/mnt/skills"
def test_reverse_resolve_path_nested(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
file_path = skills_dir / "agent" / "prompt.py"
file_path.parent.mkdir()
file_path.write_text("test")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir)),
],
)
resolved = sandbox._reverse_resolve_path(str(file_path))
assert resolved == "/mnt/skills/agent/prompt.py"
class TestReadOnlyPath:
def test_is_read_only_true(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
],
)
assert sandbox._is_read_only_path("/home/user/skills/file.py") is True
def test_is_read_only_false_for_writable(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path="/home/user/data", read_only=False),
],
)
assert sandbox._is_read_only_path("/home/user/data/file.txt") is False
def test_is_read_only_false_for_unmapped_path(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
],
)
# Path not under any mapping
assert sandbox._is_read_only_path("/tmp/other/file.txt") is False
def test_is_read_only_true_for_exact_match(self):
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path="/home/user/skills", read_only=True),
],
)
assert sandbox._is_read_only_path("/home/user/skills") is True
def test_write_file_blocked_on_read_only(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
],
)
# Skills dir is read-only, write should be blocked
with pytest.raises(OSError) as exc_info:
sandbox.write_file("/mnt/skills/new_file.py", "content")
assert exc_info.value.errno == errno.EROFS
def test_write_file_allowed_on_writable_mount(self, tmp_path):
data_dir = tmp_path / "data"
data_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False),
],
)
sandbox.write_file("/mnt/data/file.txt", "content")
assert (data_dir / "file.txt").read_text() == "content"
def test_update_file_blocked_on_read_only(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
existing_file = skills_dir / "existing.py"
existing_file.write_bytes(b"original")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
],
)
with pytest.raises(OSError) as exc_info:
sandbox.update_file("/mnt/skills/existing.py", b"updated")
assert exc_info.value.errno == errno.EROFS
class TestMultipleMounts:
def test_multiple_read_write_mounts(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
data_dir = tmp_path / "data"
data_dir.mkdir()
external_dir = tmp_path / "external"
external_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/skills", local_path=str(skills_dir), read_only=True),
PathMapping(container_path="/mnt/data", local_path=str(data_dir), read_only=False),
PathMapping(container_path="/mnt/external", local_path=str(external_dir), read_only=True),
],
)
# Skills is read-only
with pytest.raises(OSError):
sandbox.write_file("/mnt/skills/file.py", "content")
# Data is writable
sandbox.write_file("/mnt/data/file.txt", "data content")
assert (data_dir / "file.txt").read_text() == "data content"
# External is read-only
with pytest.raises(OSError):
sandbox.write_file("/mnt/external/file.txt", "content")
def test_nested_mounts_writable_under_readonly(self, tmp_path):
"""A writable mount nested under a read-only mount should allow writes."""
ro_dir = tmp_path / "ro"
ro_dir.mkdir()
rw_dir = ro_dir / "writable"
rw_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/repo", local_path=str(ro_dir), read_only=True),
PathMapping(container_path="/mnt/repo/writable", local_path=str(rw_dir), read_only=False),
],
)
# Parent mount is read-only
with pytest.raises(OSError):
sandbox.write_file("/mnt/repo/file.txt", "content")
# Nested writable mount should allow writes
sandbox.write_file("/mnt/repo/writable/file.txt", "content")
assert (rw_dir / "file.txt").read_text() == "content"
def test_execute_command_path_replacement(self, tmp_path, monkeypatch):
data_dir = tmp_path / "data"
data_dir.mkdir()
test_file = data_dir / "test.txt"
test_file.write_text("hello")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path=str(data_dir)),
],
)
# Mock subprocess to capture the resolved command
captured = {}
original_run = __import__("subprocess").run
def mock_run(*args, **kwargs):
if len(args) > 0:
captured["command"] = args[0]
return original_run(*args, **kwargs)
monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.subprocess.run", mock_run)
monkeypatch.setattr("deerflow.sandbox.local.local_sandbox.LocalSandbox._get_shell", lambda self: "/bin/sh")
sandbox.execute_command("cat /mnt/data/test.txt")
# Verify the command received the resolved local path
assert str(data_dir) in captured.get("command", "")
def test_reverse_resolve_path_does_not_match_partial_prefix(self, tmp_path):
foo_dir = tmp_path / "foo"
foo_dir.mkdir()
foobar_dir = tmp_path / "foobar"
foobar_dir.mkdir()
target = foobar_dir / "file.txt"
target.write_text("test")
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/foo", local_path=str(foo_dir)),
],
)
resolved = sandbox._reverse_resolve_path(str(target))
assert resolved == str(target.resolve())
def test_reverse_resolve_paths_in_output_supports_backslash_separator(self, tmp_path):
mount_dir = tmp_path / "mount"
mount_dir.mkdir()
sandbox = LocalSandbox(
"test",
[
PathMapping(container_path="/mnt/data", local_path=str(mount_dir)),
],
)
output = f"Copied: {mount_dir}\\file.txt"
masked = sandbox._reverse_resolve_paths_in_output(output)
assert "/mnt/data/file.txt" in masked
assert str(mount_dir) not in masked
class TestLocalSandboxProviderMounts:
def test_setup_path_mappings_uses_configured_skills_container_path_as_reserved_prefix(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
custom_dir = tmp_path / "custom"
custom_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(custom_dir), container_path="/custom-skills/nested", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/custom-skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/custom-skills"]
def test_setup_path_mappings_skips_relative_host_path(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path="relative/path", container_path="/mnt/data", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
def test_setup_path_mappings_skips_non_absolute_container_path(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
custom_dir = tmp_path / "custom"
custom_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(custom_dir), container_path="mnt/data", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills"]
def test_setup_path_mappings_normalizes_container_path_trailing_slash(self, tmp_path):
skills_dir = tmp_path / "skills"
skills_dir.mkdir()
custom_dir = tmp_path / "custom"
custom_dir.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
sandbox_config = SandboxConfig(
use="deerflow.sandbox.local:LocalSandboxProvider",
mounts=[
VolumeMountConfig(host_path=str(custom_dir), container_path="/mnt/data/", read_only=False),
],
)
config = SimpleNamespace(
skills=SimpleNamespace(container_path="/mnt/skills", get_skills_path=lambda: skills_dir),
sandbox=sandbox_config,
)
with patch("deerflow.config.get_app_config", return_value=config):
provider = LocalSandboxProvider()
assert [m.container_path for m in provider._path_mappings] == ["/mnt/skills", "/mnt/data"]
+119 -2
View File
@@ -1,5 +1,6 @@
"""Tests for LoopDetectionMiddleware."""
import copy
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
@@ -19,8 +20,13 @@ def _make_runtime(thread_id="test-thread"):
def _make_state(tool_calls=None, content=""):
"""Build a minimal AgentState dict with an AIMessage."""
msg = AIMessage(content=content, tool_calls=tool_calls or [])
"""Build a minimal AgentState dict with an AIMessage.
Deep-copies *content* when it is mutable (e.g. list) so that
successive calls never share the same object reference.
"""
safe_content = copy.deepcopy(content) if isinstance(content, list) else content
msg = AIMessage(content=safe_content, tool_calls=tool_calls or [])
return {"messages": [msg]}
@@ -229,3 +235,114 @@ class TestLoopDetection:
mw._apply(_make_state(tool_calls=call), runtime)
assert "default" in mw._history
class TestAppendText:
"""Unit tests for LoopDetectionMiddleware._append_text."""
def test_none_content_returns_text(self):
result = LoopDetectionMiddleware._append_text(None, "hello")
assert result == "hello"
def test_str_content_concatenates(self):
result = LoopDetectionMiddleware._append_text("existing", "appended")
assert result == "existing\n\nappended"
def test_empty_str_content_concatenates(self):
result = LoopDetectionMiddleware._append_text("", "appended")
assert result == "\n\nappended"
def test_list_content_appends_text_block(self):
"""List content (e.g. Anthropic thinking mode) should get a new text block."""
content = [
{"type": "thinking", "text": "Let me think..."},
{"type": "text", "text": "Here is my answer"},
]
result = LoopDetectionMiddleware._append_text(content, "stop msg")
assert isinstance(result, list)
assert len(result) == 3
assert result[0] == content[0]
assert result[1] == content[1]
assert result[2] == {"type": "text", "text": "\n\nstop msg"}
def test_empty_list_content_appends_text_block(self):
result = LoopDetectionMiddleware._append_text([], "stop msg")
assert isinstance(result, list)
assert len(result) == 1
assert result[0] == {"type": "text", "text": "\n\nstop msg"}
def test_unexpected_type_coerced_to_str(self):
"""Unexpected content types should be coerced to str as a fallback."""
result = LoopDetectionMiddleware._append_text(42, "stop msg")
assert isinstance(result, str)
assert result == "42\n\nstop msg"
def test_list_content_not_mutated_in_place(self):
"""_append_text must not modify the original list."""
original = [{"type": "text", "text": "hello"}]
result = LoopDetectionMiddleware._append_text(original, "appended")
assert len(original) == 1 # original unchanged
assert len(result) == 2 # new list has the appended block
class TestHardStopWithListContent:
"""Regression tests: hard stop must not crash when AIMessage.content is a list."""
def test_hard_stop_with_list_content(self):
"""Hard stop on list content should not raise TypeError (regression)."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
# Build state with list content (e.g. Anthropic thinking mode)
list_content = [
{"type": "thinking", "text": "Let me think..."},
{"type": "text", "text": "I'll run ls"},
]
for _ in range(3):
mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
# Fourth call triggers hard stop — must not raise TypeError
result = mw._apply(_make_state(tool_calls=call, content=list_content), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg, AIMessage)
assert msg.tool_calls == []
# Content should remain a list with the stop message appended
assert isinstance(msg.content, list)
assert len(msg.content) == 3
assert msg.content[2]["type"] == "text"
assert _HARD_STOP_MSG in msg.content[2]["text"]
def test_hard_stop_with_none_content(self):
"""Hard stop on None content should produce a plain string."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call), runtime)
# Fourth call with default empty-string content
result = mw._apply(_make_state(tool_calls=call), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg.content, str)
assert _HARD_STOP_MSG in msg.content
def test_hard_stop_with_str_content(self):
"""Hard stop on str content should concatenate the stop message."""
mw = LoopDetectionMiddleware(warn_threshold=2, hard_limit=4)
runtime = _make_runtime()
call = [_bash_call("ls")]
for _ in range(3):
mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
result = mw._apply(_make_state(tool_calls=call, content="thinking..."), runtime)
assert result is not None
msg = result["messages"][0]
assert isinstance(msg.content, str)
assert msg.content.startswith("thinking...")
assert _HARD_STOP_MSG in msg.content
@@ -154,3 +154,22 @@ def test_format_memory_renders_correction_without_source_error_normally() -> Non
assert "Use make dev for local development." in result
assert "avoid:" not in result
def test_format_memory_includes_long_term_background() -> None:
"""longTermBackground in history must be injected into the prompt."""
memory_data = {
"user": {},
"history": {
"recentMonths": {"summary": "Recent activity summary"},
"earlierContext": {"summary": "Earlier context summary"},
"longTermBackground": {"summary": "Core expertise in distributed systems"},
},
"facts": [],
}
result = format_memory_for_injection(memory_data, max_tokens=2000)
assert "Background: Core expertise in distributed systems" in result
assert "Recent: Recent activity summary" in result
assert "Earlier: Earlier context summary" in result
+41
View File
@@ -47,4 +47,45 @@ def test_process_queue_forwards_correction_flag_to_updater() -> None:
thread_id="thread-1",
agent_name="lead_agent",
correction_detected=True,
reinforcement_detected=False,
)
def test_queue_add_preserves_existing_reinforcement_flag_for_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["first"], reinforcement_detected=True)
queue.add(thread_id="thread-1", messages=["second"], reinforcement_detected=False)
assert len(queue._queue) == 1
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].reinforcement_detected is True
def test_process_queue_forwards_reinforcement_flag_to_updater() -> None:
queue = MemoryUpdateQueue()
queue._queue = [
ConversationContext(
thread_id="thread-1",
messages=["conversation"],
agent_name="lead_agent",
reinforcement_detected=True,
)
]
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater):
queue._process_queue()
mock_updater.update_memory.assert_called_once_with(
messages=["conversation"],
thread_id="thread-1",
agent_name="lead_agent",
correction_detected=False,
reinforcement_detected=True,
)
+153
View File
@@ -619,3 +619,156 @@ class TestUpdateMemoryStructuredResponse:
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" not in prompt
class TestFactDeduplicationCaseInsensitive:
"""Tests that fact deduplication is case-insensitive."""
def test_duplicate_fact_different_case_not_stored(self):
updater = MemoryUpdater()
current_memory = _make_memory(
facts=[
{
"id": "fact_1",
"content": "User prefers Python",
"category": "preference",
"confidence": 0.9,
"createdAt": "2026-01-01T00:00:00Z",
"source": "thread-a",
},
]
)
# Same fact with different casing should be treated as duplicate
update_data = {
"factsToRemove": [],
"newFacts": [
{"content": "user prefers python", "category": "preference", "confidence": 0.95},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
# Should still have only 1 fact (duplicate rejected)
assert len(result["facts"]) == 1
assert result["facts"][0]["content"] == "User prefers Python"
def test_unique_fact_different_case_and_content_stored(self):
updater = MemoryUpdater()
current_memory = _make_memory(
facts=[
{
"id": "fact_1",
"content": "User prefers Python",
"category": "preference",
"confidence": 0.9,
"createdAt": "2026-01-01T00:00:00Z",
"source": "thread-a",
},
]
)
update_data = {
"factsToRemove": [],
"newFacts": [
{"content": "User prefers Go", "category": "preference", "confidence": 0.85},
],
}
with patch(
"deerflow.agents.memory.updater.get_memory_config",
return_value=_memory_config(max_facts=100, fact_confidence_threshold=0.7),
):
result = updater._apply_updates(current_memory, update_data, thread_id="thread-b")
assert len(result["facts"]) == 2
class TestReinforcementHint:
"""Tests that reinforcement_detected injects the correct hint into the prompt."""
@staticmethod
def _make_mock_model(json_response: str):
model = MagicMock()
response = MagicMock()
response.content = f"```json\n{json_response}\n```"
model.invoke.return_value = response
return model
def test_reinforcement_hint_injected_when_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Yes, exactly! That's what I needed."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Great to hear!"
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Positive reinforcement signals were detected" in prompt
def test_reinforcement_hint_absent_when_not_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "Tell me more."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Sure."
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], reinforcement_detected=False)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Positive reinforcement signals were detected" not in prompt
def test_both_hints_present_when_both_detected(self):
updater = MemoryUpdater()
valid_json = '{"user": {}, "history": {}, "newFacts": [], "factsToRemove": []}'
model = self._make_mock_model(valid_json)
with (
patch.object(updater, "_get_model", return_value=model),
patch("deerflow.agents.memory.updater.get_memory_config", return_value=_memory_config(enabled=True)),
patch("deerflow.agents.memory.updater.get_memory_data", return_value=_make_memory()),
patch("deerflow.agents.memory.updater.get_memory_storage", return_value=MagicMock(save=MagicMock(return_value=True))),
):
msg = MagicMock()
msg.type = "human"
msg.content = "No wait, that's wrong. Actually yes, exactly right."
ai_msg = MagicMock()
ai_msg.type = "ai"
ai_msg.content = "Got it."
ai_msg.tool_calls = []
result = updater.update_memory([msg, ai_msg], correction_detected=True, reinforcement_detected=True)
assert result is True
prompt = model.invoke.call_args[0][0]
assert "Explicit correction signals were detected" in prompt
assert "Positive reinforcement signals were detected" in prompt
+71 -1
View File
@@ -10,7 +10,7 @@ persisting in long-term memory:
from langchain_core.messages import AIMessage, HumanMessage, ToolMessage
from deerflow.agents.memory.updater import _strip_upload_mentions_from_memory
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction
from deerflow.agents.middlewares.memory_middleware import _filter_messages_for_memory, detect_correction, detect_reinforcement
# ---------------------------------------------------------------------------
# Helpers
@@ -270,3 +270,73 @@ class TestStripUploadMentionsFromMemory:
mem = {"user": {}, "history": {}, "facts": []}
result = _strip_upload_mentions_from_memory(mem)
assert result == {"user": {}, "history": {}, "facts": []}
# ===========================================================================
# detect_reinforcement
# ===========================================================================
class TestDetectReinforcement:
def test_detects_english_reinforcement_signal(self):
msgs = [
_human("Can you summarise it in bullet points?"),
_ai("Here are the key points: ..."),
_human("Yes, exactly! That's what I needed."),
_ai("Glad it helped."),
]
assert detect_reinforcement(msgs) is True
def test_detects_perfect_signal(self):
msgs = [
_human("Write it more concisely."),
_ai("Here is the concise version."),
_human("Perfect."),
_ai("Great!"),
]
assert detect_reinforcement(msgs) is True
def test_detects_chinese_reinforcement_signal(self):
msgs = [
_human("帮我用要点来总结"),
_ai("好的,要点如下:..."),
_human("完全正确,就是这个意思"),
_ai("很高兴能帮到你"),
]
assert detect_reinforcement(msgs) is True
def test_returns_false_without_signal(self):
msgs = [
_human("What does this function do?"),
_ai("It processes the input data."),
_human("Can you show me an example?"),
]
assert detect_reinforcement(msgs) is False
def test_only_checks_recent_messages(self):
# Reinforcement signal buried beyond the -6 window should not trigger
msgs = [
_human("Yes, exactly right."),
_ai("Noted."),
_human("Let's discuss tests."),
_ai("Sure."),
_human("What about linting?"),
_ai("Use ruff."),
_human("And formatting?"),
_ai("Use make format."),
]
assert detect_reinforcement(msgs) is False
def test_does_not_conflict_with_correction(self):
# A message can trigger correction but not reinforcement
msgs = [
_human("That's wrong, try again."),
_ai("Corrected."),
]
assert detect_reinforcement(msgs) is False
+393
View File
@@ -0,0 +1,393 @@
from types import SimpleNamespace
from unittest.mock import patch
from deerflow.community.aio_sandbox.aio_sandbox import AioSandbox
from deerflow.sandbox.local.local_sandbox import LocalSandbox
from deerflow.sandbox.search import GrepMatch, find_glob_matches, find_grep_matches
from deerflow.sandbox.tools import glob_tool, grep_tool
def _make_runtime(tmp_path):
workspace = tmp_path / "workspace"
uploads = tmp_path / "uploads"
outputs = tmp_path / "outputs"
workspace.mkdir()
uploads.mkdir()
outputs.mkdir()
return SimpleNamespace(
state={
"sandbox": {"sandbox_id": "local"},
"thread_data": {
"workspace_path": str(workspace),
"uploads_path": str(uploads),
"outputs_path": str(outputs),
},
},
context={"thread_id": "thread-1"},
)
def test_glob_tool_returns_virtual_paths_and_ignores_common_dirs(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "app.py").write_text("print('hi')\n", encoding="utf-8")
(workspace / "pkg").mkdir()
(workspace / "pkg" / "util.py").write_text("print('util')\n", encoding="utf-8")
(workspace / "node_modules").mkdir()
(workspace / "node_modules" / "skip.py").write_text("ignored\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = glob_tool.func(
runtime=runtime,
description="find python files",
pattern="**/*.py",
path="/mnt/user-data/workspace",
)
assert "/mnt/user-data/workspace/app.py" in result
assert "/mnt/user-data/workspace/pkg/util.py" in result
assert "node_modules" not in result
assert str(workspace) not in result
def test_glob_tool_supports_skills_virtual_paths(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
skills_dir = tmp_path / "skills"
(skills_dir / "public" / "demo").mkdir(parents=True)
(skills_dir / "public" / "demo" / "SKILL.md").write_text("# Demo\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
with (
patch("deerflow.sandbox.tools._get_skills_container_path", return_value="/mnt/skills"),
patch("deerflow.sandbox.tools._get_skills_host_path", return_value=str(skills_dir)),
):
result = glob_tool.func(
runtime=runtime,
description="find skills",
pattern="**/SKILL.md",
path="/mnt/skills",
)
assert "/mnt/skills/public/demo/SKILL.md" in result
assert str(skills_dir) not in result
def test_grep_tool_filters_by_glob_and_skips_binary_files(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "main.py").write_text("TODO = 'ship it'\nprint(TODO)\n", encoding="utf-8")
(workspace / "notes.txt").write_text("TODO in txt should be filtered\n", encoding="utf-8")
(workspace / "image.bin").write_bytes(b"\0binary TODO")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = grep_tool.func(
runtime=runtime,
description="find todo references",
pattern="TODO",
path="/mnt/user-data/workspace",
glob="**/*.py",
)
assert "/mnt/user-data/workspace/main.py:1: TODO = 'ship it'" in result
assert "notes.txt" not in result
assert "image.bin" not in result
assert str(workspace) not in result
def test_grep_tool_truncates_results(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "main.py").write_text("TODO one\nTODO two\nTODO three\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
# Prevent config.yaml tool config from overriding the caller-supplied max_results=2.
monkeypatch.setattr("deerflow.sandbox.tools.get_app_config", lambda: SimpleNamespace(get_tool_config=lambda name: None))
result = grep_tool.func(
runtime=runtime,
description="limit matches",
pattern="TODO",
path="/mnt/user-data/workspace",
max_results=2,
)
assert "Found 2 matches under /mnt/user-data/workspace (showing first 2)" in result
assert "TODO one" in result
assert "TODO two" in result
assert "TODO three" not in result
assert "Results truncated." in result
def test_glob_tool_include_dirs_filters_nested_ignored_paths(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "src").mkdir()
(workspace / "src" / "main.py").write_text("x\n", encoding="utf-8")
(workspace / "node_modules").mkdir()
(workspace / "node_modules" / "lib").mkdir()
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = glob_tool.func(
runtime=runtime,
description="find dirs",
pattern="**",
path="/mnt/user-data/workspace",
include_dirs=True,
)
assert "src" in result
assert "node_modules" not in result
def test_grep_tool_literal_mode(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "file.py").write_text("price = (a+b)\nresult = a+b\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
# literal=True should treat (a+b) as a plain string, not a regex group
result = grep_tool.func(
runtime=runtime,
description="literal search",
pattern="(a+b)",
path="/mnt/user-data/workspace",
literal=True,
)
assert "price = (a+b)" in result
assert "result = a+b" not in result
def test_grep_tool_case_sensitive(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "file.py").write_text("TODO: fix\ntodo: also fix\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = grep_tool.func(
runtime=runtime,
description="case sensitive search",
pattern="TODO",
path="/mnt/user-data/workspace",
case_sensitive=True,
)
assert "TODO: fix" in result
assert "todo: also fix" not in result
def test_grep_tool_invalid_regex_returns_error(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
result = grep_tool.func(
runtime=runtime,
description="bad pattern",
pattern="[invalid",
path="/mnt/user-data/workspace",
)
assert "Invalid regex pattern" in result
def test_aio_sandbox_glob_include_dirs_filters_nested_ignored(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(name="src", path="/mnt/workspace/src"),
SimpleNamespace(name="node_modules", path="/mnt/workspace/node_modules"),
# child of node_modules — should be filtered via should_ignore_path
SimpleNamespace(name="lib", path="/mnt/workspace/node_modules/lib"),
]
)
),
)
matches, truncated = sandbox.glob("/mnt/workspace", "**", include_dirs=True)
assert "/mnt/workspace/src" in matches
assert "/mnt/workspace/node_modules" not in matches
assert "/mnt/workspace/node_modules/lib" not in matches
assert truncated is False
def test_aio_sandbox_grep_invalid_regex_raises() -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
import re
try:
sandbox.grep("/mnt/workspace", "[invalid")
assert False, "Expected re.error"
except re.error:
pass
def test_aio_sandbox_glob_parses_json(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"find_files",
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(files=["/mnt/user-data/workspace/app.py", "/mnt/user-data/workspace/node_modules/skip.py"])),
)
matches, truncated = sandbox.glob("/mnt/user-data/workspace", "**/*.py")
assert matches == ["/mnt/user-data/workspace/app.py"]
assert truncated is False
def test_aio_sandbox_grep_parses_json(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(
name="app.py",
path="/mnt/user-data/workspace/app.py",
is_directory=False,
)
]
)
),
)
monkeypatch.setattr(
sandbox._client.file,
"search_in_file",
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(line_numbers=[7], matches=["TODO = True"])),
)
matches, truncated = sandbox.grep("/mnt/user-data/workspace", "TODO")
assert matches == [GrepMatch(path="/mnt/user-data/workspace/app.py", line_number=7, line="TODO = True")]
assert truncated is False
def test_find_glob_matches_raises_not_a_directory(tmp_path) -> None:
file_path = tmp_path / "file.txt"
file_path.write_text("x\n", encoding="utf-8")
try:
find_glob_matches(file_path, "**/*.py")
assert False, "Expected NotADirectoryError"
except NotADirectoryError:
pass
def test_find_grep_matches_raises_not_a_directory(tmp_path) -> None:
file_path = tmp_path / "file.txt"
file_path.write_text("TODO\n", encoding="utf-8")
try:
find_grep_matches(file_path, "TODO")
assert False, "Expected NotADirectoryError"
except NotADirectoryError:
pass
def test_find_grep_matches_skips_symlink_outside_root(tmp_path) -> None:
workspace = tmp_path / "workspace"
workspace.mkdir()
outside = tmp_path / "outside.txt"
outside.write_text("TODO outside\n", encoding="utf-8")
(workspace / "outside-link.txt").symlink_to(outside)
matches, truncated = find_grep_matches(workspace, "TODO")
assert matches == []
assert truncated is False
def test_glob_tool_honors_smaller_requested_max_results(tmp_path, monkeypatch) -> None:
runtime = _make_runtime(tmp_path)
workspace = tmp_path / "workspace"
(workspace / "a.py").write_text("print('a')\n", encoding="utf-8")
(workspace / "b.py").write_text("print('b')\n", encoding="utf-8")
(workspace / "c.py").write_text("print('c')\n", encoding="utf-8")
monkeypatch.setattr("deerflow.sandbox.tools.ensure_sandbox_initialized", lambda runtime: LocalSandbox(id="local"))
monkeypatch.setattr(
"deerflow.sandbox.tools.get_app_config",
lambda: SimpleNamespace(get_tool_config=lambda name: SimpleNamespace(model_extra={"max_results": 50})),
)
result = glob_tool.func(
runtime=runtime,
description="limit glob matches",
pattern="**/*.py",
path="/mnt/user-data/workspace",
max_results=2,
)
assert "Found 2 paths under /mnt/user-data/workspace (showing first 2)" in result
assert "Results truncated." in result
def test_aio_sandbox_glob_include_dirs_enforces_root_boundary(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(name="src", path="/mnt/workspace/src"),
SimpleNamespace(name="src2", path="/mnt/workspace2/src2"),
]
)
),
)
matches, truncated = sandbox.glob("/mnt/workspace", "**", include_dirs=True)
assert matches == ["/mnt/workspace/src"]
assert truncated is False
def test_aio_sandbox_grep_skips_mismatched_line_number_payloads(monkeypatch) -> None:
with patch("deerflow.community.aio_sandbox.aio_sandbox.AioSandboxClient"):
sandbox = AioSandbox(id="test-sandbox", base_url="http://localhost:8080")
monkeypatch.setattr(
sandbox._client.file,
"list_path",
lambda **kwargs: SimpleNamespace(
data=SimpleNamespace(
files=[
SimpleNamespace(
name="app.py",
path="/mnt/user-data/workspace/app.py",
is_directory=False,
)
]
)
),
)
monkeypatch.setattr(
sandbox._client.file,
"search_in_file",
lambda **kwargs: SimpleNamespace(data=SimpleNamespace(line_numbers=[7], matches=["TODO = True", "extra"])),
)
matches, truncated = sandbox.grep("/mnt/user-data/workspace", "TODO")
assert matches == [GrepMatch(path="/mnt/user-data/workspace/app.py", line_number=7, line="TODO = True")]
assert truncated is False
@@ -8,7 +8,10 @@ import pytest
from deerflow.sandbox.tools import (
VIRTUAL_PATH_PREFIX,
_apply_cwd_prefix,
_get_custom_mount_for_path,
_get_custom_mounts,
_is_acp_workspace_path,
_is_custom_mount_path,
_is_skills_path,
_reject_path_traversal,
_resolve_acp_workspace_path,
@@ -39,6 +42,53 @@ def test_replace_virtual_path_maps_virtual_root_and_subpaths() -> None:
assert Path(replace_virtual_path("/mnt/user-data", _THREAD_DATA)).as_posix() == "/tmp/deer-flow/threads/t1/user-data"
def test_replace_virtual_path_preserves_trailing_slash() -> None:
"""Trailing slash must survive virtual-to-actual path replacement.
Regression: '/mnt/user-data/workspace/' was previously returned without
the trailing slash, causing string concatenations like
output_dir + 'file.txt' to produce a missing-separator path.
"""
result = replace_virtual_path("/mnt/user-data/workspace/", _THREAD_DATA)
assert result.endswith("/"), f"Expected trailing slash, got: {result!r}"
assert result == "/tmp/deer-flow/threads/t1/user-data/workspace/"
def test_replace_virtual_path_preserves_trailing_slash_windows_style() -> None:
"""Trailing slash must be preserved as backslash when actual_base is Windows-style.
If actual_base uses backslash separators, appending '/' would produce a
mixed-separator path. The separator must match the style of actual_base.
"""
win_thread_data = {
"workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
"uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
"outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
}
result = replace_virtual_path("/mnt/user-data/workspace/", win_thread_data)
assert result.endswith("\\"), f"Expected trailing backslash for Windows path, got: {result!r}"
assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
def test_replace_virtual_path_preserves_windows_style_for_nested_subdir_trailing_slash() -> None:
"""Nested Windows-style subdirectories must keep backslashes throughout."""
win_thread_data = {
"workspace_path": r"C:\deer-flow\threads\t1\user-data\workspace",
"uploads_path": r"C:\deer-flow\threads\t1\user-data\uploads",
"outputs_path": r"C:\deer-flow\threads\t1\user-data\outputs",
}
result = replace_virtual_path("/mnt/user-data/workspace/subdir/", win_thread_data)
assert result == "C:\\deer-flow\\threads\\t1\\user-data\\workspace\\subdir\\"
assert "/" not in result, f"Mixed separators in Windows path: {result!r}"
def test_replace_virtual_paths_in_command_preserves_trailing_slash() -> None:
"""Trailing slash on a virtual path inside a command must be preserved."""
cmd = """python -c "output_dir = '/mnt/user-data/workspace/'; print(output_dir + 'some_file.txt')\""""
result = replace_virtual_paths_in_command(cmd, _THREAD_DATA)
assert "/tmp/deer-flow/threads/t1/user-data/workspace/" in result, f"Trailing slash lost in: {result!r}"
# ---------- mask_local_paths_in_output ----------
@@ -96,6 +146,25 @@ def test_validate_local_tool_path_rejects_non_virtual_path() -> None:
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
def test_validate_local_tool_path_rejects_non_virtual_path_mentions_configured_mounts() -> None:
with pytest.raises(PermissionError, match="configured mount paths"):
validate_local_tool_path("/Users/someone/config.yaml", _THREAD_DATA)
def test_validate_local_tool_path_prioritizes_user_data_before_custom_mounts() -> None:
from deerflow.config.sandbox_config import VolumeMountConfig
mounts = [
VolumeMountConfig(host_path="/tmp/host-user-data", container_path=VIRTUAL_PATH_PREFIX, read_only=False),
]
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/file.txt", _THREAD_DATA, read_only=True)
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path(f"{VIRTUAL_PATH_PREFIX}/workspace/../../etc/passwd", _THREAD_DATA, read_only=True)
def test_validate_local_tool_path_rejects_bare_virtual_root() -> None:
"""The bare /mnt/user-data root without trailing slash is not a valid sub-path."""
with pytest.raises(PermissionError, match="Only paths under"):
@@ -235,6 +304,22 @@ def test_validate_local_bash_command_paths_blocks_host_paths() -> None:
validate_local_bash_command_paths("cat /etc/passwd", _THREAD_DATA)
def test_validate_local_bash_command_paths_allows_https_urls() -> None:
"""URLs like https://github.com/... must not be flagged as unsafe absolute paths."""
validate_local_bash_command_paths(
"cd /mnt/user-data/workspace && git clone https://github.com/CherryHQ/cherry-studio.git",
_THREAD_DATA,
)
def test_validate_local_bash_command_paths_allows_http_urls() -> None:
"""HTTP URLs must not be flagged as unsafe absolute paths."""
validate_local_bash_command_paths(
"curl http://example.com/file.tar.gz -o /mnt/user-data/workspace/file.tar.gz",
_THREAD_DATA,
)
def test_validate_local_bash_command_paths_allows_virtual_and_system_paths() -> None:
validate_local_bash_command_paths(
"/bin/echo ok > /mnt/user-data/workspace/out.txt && cat /dev/null",
@@ -567,6 +652,156 @@ def test_validate_local_bash_command_paths_allows_mcp_filesystem_paths() -> None
validate_local_bash_command_paths("ls /mnt/d/workspace", _THREAD_DATA)
# ---------- Custom mount path tests ----------
def _mock_custom_mounts():
"""Create mock VolumeMountConfig objects for testing."""
from deerflow.config.sandbox_config import VolumeMountConfig
return [
VolumeMountConfig(host_path="/home/user/code-read", container_path="/mnt/code-read", read_only=True),
VolumeMountConfig(host_path="/home/user/data", container_path="/mnt/data", read_only=False),
]
def test_is_custom_mount_path_recognises_configured_mounts() -> None:
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
assert _is_custom_mount_path("/mnt/code-read") is True
assert _is_custom_mount_path("/mnt/code-read/src/main.py") is True
assert _is_custom_mount_path("/mnt/data") is True
assert _is_custom_mount_path("/mnt/data/file.txt") is True
assert _is_custom_mount_path("/mnt/code-read-extra/foo") is False
assert _is_custom_mount_path("/mnt/other") is False
def test_get_custom_mount_for_path_returns_longest_prefix() -> None:
from deerflow.config.sandbox_config import VolumeMountConfig
mounts = [
VolumeMountConfig(host_path="/var/mnt", container_path="/mnt", read_only=False),
VolumeMountConfig(host_path="/home/user/code", container_path="/mnt/code", read_only=True),
]
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=mounts):
mount = _get_custom_mount_for_path("/mnt/code/file.py")
assert mount is not None
assert mount.container_path == "/mnt/code"
def test_validate_local_tool_path_allows_custom_mount_read() -> None:
"""read_file / ls should be able to access custom mount paths."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=True)
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=True)
def test_validate_local_tool_path_blocks_read_only_mount_write() -> None:
"""write_file / str_replace must NOT write to read-only custom mounts."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="Write access to read-only mount is not allowed"):
validate_local_tool_path("/mnt/code-read/src/main.py", _THREAD_DATA, read_only=False)
def test_validate_local_tool_path_allows_writable_mount_write() -> None:
"""write_file / str_replace should succeed on writable custom mounts."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_tool_path("/mnt/data/file.txt", _THREAD_DATA, read_only=False)
def test_validate_local_tool_path_blocks_traversal_in_custom_mount() -> None:
"""Path traversal via .. in custom mount paths must be rejected."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_tool_path("/mnt/code-read/../../etc/passwd", _THREAD_DATA, read_only=True)
def test_validate_local_bash_command_paths_allows_custom_mount() -> None:
"""bash commands referencing custom mount paths should be allowed."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
validate_local_bash_command_paths("cat /mnt/code-read/src/main.py", _THREAD_DATA)
validate_local_bash_command_paths("ls /mnt/data", _THREAD_DATA)
def test_validate_local_bash_command_paths_blocks_traversal_in_custom_mount() -> None:
"""Bash commands with traversal in custom mount paths should be blocked."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="path traversal"):
validate_local_bash_command_paths("cat /mnt/code-read/../../etc/passwd", _THREAD_DATA)
def test_validate_local_bash_command_paths_still_blocks_non_mount_paths() -> None:
"""Paths not matching any custom mount should still be blocked."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
with pytest.raises(PermissionError, match="Unsafe absolute paths"):
validate_local_bash_command_paths("cat /etc/shadow", _THREAD_DATA)
def test_get_custom_mounts_caching(monkeypatch, tmp_path) -> None:
"""_get_custom_mounts should cache after first successful load."""
# Clear any existing cache
if hasattr(_get_custom_mounts, "_cached"):
monkeypatch.delattr(_get_custom_mounts, "_cached")
# Use real directories so host_path.exists() filtering passes
dir_a = tmp_path / "code-read"
dir_a.mkdir()
dir_b = tmp_path / "data"
dir_b.mkdir()
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
mounts = [
VolumeMountConfig(host_path=str(dir_a), container_path="/mnt/code-read", read_only=True),
VolumeMountConfig(host_path=str(dir_b), container_path="/mnt/data", read_only=False),
]
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 2
# After caching, should return cached value even without mock
assert hasattr(_get_custom_mounts, "_cached")
assert len(_get_custom_mounts()) == 2
# Cleanup
monkeypatch.delattr(_get_custom_mounts, "_cached")
def test_get_custom_mounts_filters_nonexistent_host_path(monkeypatch, tmp_path) -> None:
"""_get_custom_mounts should only return mounts whose host_path exists."""
if hasattr(_get_custom_mounts, "_cached"):
monkeypatch.delattr(_get_custom_mounts, "_cached")
from deerflow.config.sandbox_config import SandboxConfig, VolumeMountConfig
existing_dir = tmp_path / "existing"
existing_dir.mkdir()
mounts = [
VolumeMountConfig(host_path=str(existing_dir), container_path="/mnt/existing", read_only=True),
VolumeMountConfig(host_path="/nonexistent/path/12345", container_path="/mnt/ghost", read_only=False),
]
mock_sandbox = SandboxConfig(use="deerflow.sandbox.local:LocalSandboxProvider", mounts=mounts)
mock_config = SimpleNamespace(sandbox=mock_sandbox)
with patch("deerflow.config.get_app_config", return_value=mock_config):
result = _get_custom_mounts()
assert len(result) == 1
assert result[0].container_path == "/mnt/existing"
# Cleanup
monkeypatch.delattr(_get_custom_mounts, "_cached")
def test_get_custom_mount_for_path_boundary_no_false_prefix_match() -> None:
"""_get_custom_mount_for_path must not match /mnt/code-read-extra for /mnt/code-read."""
with patch("deerflow.sandbox.tools._get_custom_mounts", return_value=_mock_custom_mounts()):
mount = _get_custom_mount_for_path("/mnt/code-read-extra/foo")
assert mount is None
def test_str_replace_parallel_updates_should_preserve_both_edits(monkeypatch) -> None:
class SharedSandbox:
def __init__(self) -> None:
+187
View File
@@ -140,6 +140,193 @@ async def test_event_id_format(bridge: MemoryStreamBridge):
assert re.match(r"^\d+-\d+$", event.id), f"Expected timestamp-seq format, got {event.id}"
# ---------------------------------------------------------------------------
# END sentinel guarantee tests
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_end_sentinel_delivered_when_queue_full():
"""END sentinel must always be delivered, even when the queue is completely full.
This is the critical regression test for the bug where publish_end()
would silently drop the END sentinel when the queue was full, causing
subscribe() to hang forever and leaking resources.
"""
bridge = MemoryStreamBridge(queue_maxsize=2)
run_id = "run-end-full"
# Fill the queue to capacity
await bridge.publish(run_id, "event-1", {"n": 1})
await bridge.publish(run_id, "event-2", {"n": 2})
assert bridge._queues[run_id].full()
# publish_end should succeed by evicting old events
await bridge.publish_end(run_id)
# Subscriber must receive END_SENTINEL
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
break
assert any(e is END_SENTINEL for e in events), "END sentinel was not delivered"
@pytest.mark.anyio
async def test_end_sentinel_evicts_oldest_events():
"""When queue is full, publish_end evicts the oldest events to make room."""
bridge = MemoryStreamBridge(queue_maxsize=1)
run_id = "run-evict"
# Fill queue with one event
await bridge.publish(run_id, "will-be-evicted", {})
assert bridge._queues[run_id].full()
# publish_end must succeed
await bridge.publish_end(run_id)
# The only event we should get is END_SENTINEL (the regular event was evicted)
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
break
assert len(events) == 1
assert events[0] is END_SENTINEL
@pytest.mark.anyio
async def test_end_sentinel_no_eviction_when_space_available():
"""When queue has space, publish_end should not evict anything."""
bridge = MemoryStreamBridge(queue_maxsize=10)
run_id = "run-no-evict"
await bridge.publish(run_id, "event-1", {"n": 1})
await bridge.publish(run_id, "event-2", {"n": 2})
await bridge.publish_end(run_id)
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
break
# All events plus END should be present
assert len(events) == 3
assert events[0].event == "event-1"
assert events[1].event == "event-2"
assert events[2] is END_SENTINEL
@pytest.mark.anyio
async def test_concurrent_tasks_end_sentinel():
"""Multiple concurrent producer/consumer pairs should all terminate properly.
Simulates the production scenario where multiple runs share a single
bridge instance — each must receive its own END sentinel.
"""
bridge = MemoryStreamBridge(queue_maxsize=4)
num_runs = 4
async def producer(run_id: str):
for i in range(10): # More events than queue capacity
await bridge.publish(run_id, f"event-{i}", {"i": i})
await bridge.publish_end(run_id)
async def consumer(run_id: str) -> list:
events = []
async for entry in bridge.subscribe(run_id, heartbeat_interval=0.1):
events.append(entry)
if entry is END_SENTINEL:
return events
return events # pragma: no cover
# Run producers and consumers concurrently
run_ids = [f"concurrent-{i}" for i in range(num_runs)]
producers = [producer(rid) for rid in run_ids]
consumers = [consumer(rid) for rid in run_ids]
# Start consumers first, then producers
consumer_tasks = [asyncio.create_task(c) for c in consumers]
await asyncio.gather(*producers)
results = await asyncio.wait_for(
asyncio.gather(*consumer_tasks),
timeout=10.0,
)
for i, events in enumerate(results):
assert events[-1] is END_SENTINEL, f"Run {run_ids[i]} did not receive END sentinel"
# ---------------------------------------------------------------------------
# Drop counter tests
# ---------------------------------------------------------------------------
@pytest.mark.anyio
async def test_dropped_count_tracking():
"""Dropped events should be tracked per run_id."""
bridge = MemoryStreamBridge(queue_maxsize=1)
run_id = "run-drop-count"
# Fill the queue
await bridge.publish(run_id, "first", {})
# This publish will time out and be dropped (we patch timeout to be instant)
# Instead, we verify the counter after publish_end eviction
await bridge.publish_end(run_id)
# dropped_count tracks publish() drops, not publish_end evictions
assert bridge.dropped_count(run_id) == 0
# cleanup should also clear the counter
await bridge.cleanup(run_id)
assert bridge.dropped_count(run_id) == 0
@pytest.mark.anyio
async def test_dropped_total():
"""dropped_total should sum across all runs."""
bridge = MemoryStreamBridge(queue_maxsize=256)
# No drops yet
assert bridge.dropped_total == 0
# Manually set some counts to verify the property
bridge._dropped_counts["run-a"] = 3
bridge._dropped_counts["run-b"] = 7
assert bridge.dropped_total == 10
@pytest.mark.anyio
async def test_cleanup_clears_dropped_counts():
"""cleanup() should clear the dropped counter for the run."""
bridge = MemoryStreamBridge(queue_maxsize=256)
run_id = "run-cleanup-drops"
bridge._get_or_create_queue(run_id)
bridge._dropped_counts[run_id] = 5
await bridge.cleanup(run_id)
assert run_id not in bridge._dropped_counts
@pytest.mark.anyio
async def test_close_clears_dropped_counts():
"""close() should clear all dropped counters."""
bridge = MemoryStreamBridge(queue_maxsize=256)
bridge._dropped_counts["run-x"] = 10
bridge._dropped_counts["run-y"] = 20
await bridge.close()
assert bridge.dropped_total == 0
assert len(bridge._dropped_counts) == 0
# ---------------------------------------------------------------------------
# Factory tests
# ---------------------------------------------------------------------------
+94 -34
View File
@@ -1,8 +1,8 @@
"""Tests for subagent timeout configuration.
"""Tests for subagent runtime configuration.
Covers:
- SubagentsAppConfig / SubagentOverrideConfig model validation and defaults
- get_timeout_for() resolution logic (global vs per-agent)
- get_timeout_for() / get_max_turns_for() resolution logic
- load_subagents_config_from_dict() and get_subagents_app_config() singleton
- registry.get_subagent_config() applies config overrides
- registry.list_subagents() applies overrides for all agents
@@ -24,9 +24,20 @@ from deerflow.subagents.config import SubagentConfig
# ---------------------------------------------------------------------------
def _reset_subagents_config(timeout_seconds: int = 900, agents: dict | None = None) -> None:
def _reset_subagents_config(
timeout_seconds: int = 900,
*,
max_turns: int | None = None,
agents: dict | None = None,
) -> None:
"""Reset global subagents config to a known state."""
load_subagents_config_from_dict({"timeout_seconds": timeout_seconds, "agents": agents or {}})
load_subagents_config_from_dict(
{
"timeout_seconds": timeout_seconds,
"max_turns": max_turns,
"agents": agents or {},
}
)
# ---------------------------------------------------------------------------
@@ -38,22 +49,29 @@ class TestSubagentOverrideConfig:
def test_default_is_none(self):
override = SubagentOverrideConfig()
assert override.timeout_seconds is None
assert override.max_turns is None
def test_explicit_value(self):
override = SubagentOverrideConfig(timeout_seconds=300)
override = SubagentOverrideConfig(timeout_seconds=300, max_turns=42)
assert override.timeout_seconds == 300
assert override.max_turns == 42
def test_rejects_zero(self):
with pytest.raises(ValueError):
SubagentOverrideConfig(timeout_seconds=0)
with pytest.raises(ValueError):
SubagentOverrideConfig(max_turns=0)
def test_rejects_negative(self):
with pytest.raises(ValueError):
SubagentOverrideConfig(timeout_seconds=-1)
with pytest.raises(ValueError):
SubagentOverrideConfig(max_turns=-1)
def test_minimum_valid_value(self):
override = SubagentOverrideConfig(timeout_seconds=1)
override = SubagentOverrideConfig(timeout_seconds=1, max_turns=1)
assert override.timeout_seconds == 1
assert override.max_turns == 1
# ---------------------------------------------------------------------------
@@ -66,66 +84,86 @@ class TestSubagentsAppConfigDefaults:
config = SubagentsAppConfig()
assert config.timeout_seconds == 900
def test_default_max_turns_override_is_none(self):
config = SubagentsAppConfig()
assert config.max_turns is None
def test_default_agents_empty(self):
config = SubagentsAppConfig()
assert config.agents == {}
def test_custom_global_timeout(self):
config = SubagentsAppConfig(timeout_seconds=1800)
def test_custom_global_runtime_overrides(self):
config = SubagentsAppConfig(timeout_seconds=1800, max_turns=120)
assert config.timeout_seconds == 1800
assert config.max_turns == 120
def test_rejects_zero_timeout(self):
with pytest.raises(ValueError):
SubagentsAppConfig(timeout_seconds=0)
with pytest.raises(ValueError):
SubagentsAppConfig(max_turns=0)
def test_rejects_negative_timeout(self):
with pytest.raises(ValueError):
SubagentsAppConfig(timeout_seconds=-60)
with pytest.raises(ValueError):
SubagentsAppConfig(max_turns=-60)
# ---------------------------------------------------------------------------
# SubagentsAppConfig.get_timeout_for()
# SubagentsAppConfig resolution helpers
# ---------------------------------------------------------------------------
class TestGetTimeoutFor:
class TestRuntimeResolution:
def test_returns_global_default_when_no_override(self):
config = SubagentsAppConfig(timeout_seconds=600)
assert config.get_timeout_for("general-purpose") == 600
assert config.get_timeout_for("bash") == 600
assert config.get_timeout_for("unknown-agent") == 600
assert config.get_max_turns_for("general-purpose", 100) == 100
assert config.get_max_turns_for("bash", 60) == 60
def test_returns_per_agent_override_when_set(self):
config = SubagentsAppConfig(
timeout_seconds=900,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300)},
max_turns=120,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
)
assert config.get_timeout_for("bash") == 300
assert config.get_max_turns_for("bash", 60) == 80
def test_other_agents_still_use_global_default(self):
config = SubagentsAppConfig(
timeout_seconds=900,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300)},
max_turns=140,
agents={"bash": SubagentOverrideConfig(timeout_seconds=300, max_turns=80)},
)
assert config.get_timeout_for("general-purpose") == 900
assert config.get_max_turns_for("general-purpose", 100) == 140
def test_agent_with_none_override_falls_back_to_global(self):
config = SubagentsAppConfig(
timeout_seconds=900,
agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None)},
max_turns=150,
agents={"general-purpose": SubagentOverrideConfig(timeout_seconds=None, max_turns=None)},
)
assert config.get_timeout_for("general-purpose") == 900
assert config.get_max_turns_for("general-purpose", 100) == 150
def test_multiple_per_agent_overrides(self):
config = SubagentsAppConfig(
timeout_seconds=900,
max_turns=120,
agents={
"general-purpose": SubagentOverrideConfig(timeout_seconds=1800),
"bash": SubagentOverrideConfig(timeout_seconds=120),
"general-purpose": SubagentOverrideConfig(timeout_seconds=1800, max_turns=200),
"bash": SubagentOverrideConfig(timeout_seconds=120, max_turns=80),
},
)
assert config.get_timeout_for("general-purpose") == 1800
assert config.get_timeout_for("bash") == 120
assert config.get_max_turns_for("general-purpose", 100) == 200
assert config.get_max_turns_for("bash", 60) == 80
# ---------------------------------------------------------------------------
@@ -139,54 +177,63 @@ class TestLoadSubagentsConfig:
_reset_subagents_config()
def test_load_global_timeout(self):
load_subagents_config_from_dict({"timeout_seconds": 300})
load_subagents_config_from_dict({"timeout_seconds": 300, "max_turns": 120})
assert get_subagents_app_config().timeout_seconds == 300
assert get_subagents_app_config().max_turns == 120
def test_load_with_per_agent_overrides(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {
"general-purpose": {"timeout_seconds": 1800},
"bash": {"timeout_seconds": 60},
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
"bash": {"timeout_seconds": 60, "max_turns": 80},
},
}
)
cfg = get_subagents_app_config()
assert cfg.get_timeout_for("general-purpose") == 1800
assert cfg.get_timeout_for("bash") == 60
assert cfg.get_max_turns_for("general-purpose", 100) == 200
assert cfg.get_max_turns_for("bash", 60) == 80
def test_load_partial_override(self):
load_subagents_config_from_dict(
{
"timeout_seconds": 600,
"agents": {"bash": {"timeout_seconds": 120}},
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 70}},
}
)
cfg = get_subagents_app_config()
assert cfg.get_timeout_for("general-purpose") == 600
assert cfg.get_timeout_for("bash") == 120
assert cfg.get_max_turns_for("general-purpose", 100) == 100
assert cfg.get_max_turns_for("bash", 60) == 70
def test_load_empty_dict_uses_defaults(self):
load_subagents_config_from_dict({})
cfg = get_subagents_app_config()
assert cfg.timeout_seconds == 900
assert cfg.max_turns is None
assert cfg.agents == {}
def test_load_replaces_previous_config(self):
load_subagents_config_from_dict({"timeout_seconds": 100})
load_subagents_config_from_dict({"timeout_seconds": 100, "max_turns": 90})
assert get_subagents_app_config().timeout_seconds == 100
assert get_subagents_app_config().max_turns == 90
load_subagents_config_from_dict({"timeout_seconds": 200})
load_subagents_config_from_dict({"timeout_seconds": 200, "max_turns": 110})
assert get_subagents_app_config().timeout_seconds == 200
assert get_subagents_app_config().max_turns == 110
def test_singleton_returns_same_instance_between_calls(self):
load_subagents_config_from_dict({"timeout_seconds": 777})
load_subagents_config_from_dict({"timeout_seconds": 777, "max_turns": 123})
assert get_subagents_app_config() is get_subagents_app_config()
# ---------------------------------------------------------------------------
# registry.get_subagent_config timeout override applied
# registry.get_subagent_config runtime overrides applied
# ---------------------------------------------------------------------------
@@ -211,25 +258,29 @@ class TestRegistryGetSubagentConfig:
_reset_subagents_config(timeout_seconds=900)
config = get_subagent_config("general-purpose")
assert config.timeout_seconds == 900
assert config.max_turns == 100
def test_global_timeout_override_applied(self):
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config(timeout_seconds=1800)
_reset_subagents_config(timeout_seconds=1800, max_turns=140)
config = get_subagent_config("general-purpose")
assert config.timeout_seconds == 1800
assert config.max_turns == 140
def test_per_agent_timeout_override_applied(self):
def test_per_agent_runtime_override_applied(self):
from deerflow.subagents.registry import get_subagent_config
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {"bash": {"timeout_seconds": 120}},
"max_turns": 120,
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
}
)
bash_config = get_subagent_config("bash")
assert bash_config.timeout_seconds == 120
assert bash_config.max_turns == 80
def test_per_agent_override_does_not_affect_other_agents(self):
from deerflow.subagents.registry import get_subagent_config
@@ -237,11 +288,13 @@ class TestRegistryGetSubagentConfig:
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"agents": {"bash": {"timeout_seconds": 120}},
"max_turns": 120,
"agents": {"bash": {"timeout_seconds": 120, "max_turns": 80}},
}
)
gp_config = get_subagent_config("general-purpose")
assert gp_config.timeout_seconds == 900
assert gp_config.max_turns == 120
def test_builtin_config_object_is_not_mutated(self):
"""Registry must return a new object, leaving the builtin default intact."""
@@ -249,24 +302,27 @@ class TestRegistryGetSubagentConfig:
from deerflow.subagents.registry import get_subagent_config
original_timeout = BUILTIN_SUBAGENTS["bash"].timeout_seconds
load_subagents_config_from_dict({"timeout_seconds": 42})
original_max_turns = BUILTIN_SUBAGENTS["bash"].max_turns
load_subagents_config_from_dict({"timeout_seconds": 42, "max_turns": 88})
returned = get_subagent_config("bash")
assert returned.timeout_seconds == 42
assert returned.max_turns == 88
assert BUILTIN_SUBAGENTS["bash"].timeout_seconds == original_timeout
assert BUILTIN_SUBAGENTS["bash"].max_turns == original_max_turns
def test_config_preserves_other_fields(self):
"""Applying timeout override must not change other SubagentConfig fields."""
"""Applying runtime overrides must not change other SubagentConfig fields."""
from deerflow.subagents.builtins import BUILTIN_SUBAGENTS
from deerflow.subagents.registry import get_subagent_config
_reset_subagents_config(timeout_seconds=300)
_reset_subagents_config(timeout_seconds=300, max_turns=140)
original = BUILTIN_SUBAGENTS["general-purpose"]
overridden = get_subagent_config("general-purpose")
assert overridden.name == original.name
assert overridden.description == original.description
assert overridden.max_turns == original.max_turns
assert overridden.max_turns == 140
assert overridden.model == original.model
assert overridden.tools == original.tools
assert overridden.disallowed_tools == original.disallowed_tools
@@ -291,9 +347,10 @@ class TestRegistryListSubagents:
def test_all_returned_configs_get_global_override(self):
from deerflow.subagents.registry import list_subagents
_reset_subagents_config(timeout_seconds=123)
_reset_subagents_config(timeout_seconds=123, max_turns=77)
for cfg in list_subagents():
assert cfg.timeout_seconds == 123, f"{cfg.name} has wrong timeout"
assert cfg.max_turns == 77, f"{cfg.name} has wrong max_turns"
def test_per_agent_overrides_reflected_in_list(self):
from deerflow.subagents.registry import list_subagents
@@ -301,15 +358,18 @@ class TestRegistryListSubagents:
load_subagents_config_from_dict(
{
"timeout_seconds": 900,
"max_turns": 120,
"agents": {
"general-purpose": {"timeout_seconds": 1800},
"bash": {"timeout_seconds": 60},
"general-purpose": {"timeout_seconds": 1800, "max_turns": 200},
"bash": {"timeout_seconds": 60, "max_turns": 80},
},
}
)
by_name = {cfg.name: cfg for cfg in list_subagents()}
assert by_name["general-purpose"].timeout_seconds == 1800
assert by_name["bash"].timeout_seconds == 60
assert by_name["general-purpose"].max_turns == 200
assert by_name["bash"].max_turns == 80
# ---------------------------------------------------------------------------
+5 -5
View File
@@ -1,5 +1,5 @@
import asyncio
from unittest.mock import MagicMock
from unittest.mock import AsyncMock, MagicMock
from app.gateway.routers import suggestions
@@ -43,7 +43,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.return_value = MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```')
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='```json\n["Q1", "Q2", "Q3", "Q4"]\n```'))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
@@ -61,7 +61,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.return_value = MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}])
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "text", "text": '```json\n["Q1", "Q2"]\n```'}]))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
@@ -79,7 +79,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.return_value = MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}])
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content=[{"type": "output_text", "text": '```json\n["Q1", "Q2"]\n```'}]))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
@@ -94,7 +94,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
model_name=None,
)
fake_model = MagicMock()
fake_model.invoke.side_effect = RuntimeError("boom")
fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("boom"))
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
result = asyncio.run(suggestions.generate_suggestions("t1", req))
+111
View File
@@ -0,0 +1,111 @@
"""Tests for thread_runs router with auth decorators.
These tests verify that auth decorators properly enforce permission checks
on run endpoints. They follow the same pattern as test_threads_router.py.
"""
from unittest.mock import MagicMock, patch
from uuid import uuid4
from fastapi import FastAPI
from fastapi.testclient import TestClient
from app.gateway.auth.models import User
from app.gateway.authz import AuthContext
from app.gateway.routers.thread_runs import router
def test_create_run_requires_auth():
"""POST /{thread_id}/runs requires auth."""
app = FastAPI()
app.include_router(router)
with TestClient(app, raise_server_exceptions=False) as client:
response = client.post(
"/api/threads/test-thread/runs",
json={"assistant_id": "test"},
)
assert response.status_code == 401
def test_create_run_with_auth():
"""POST /{thread_id}/runs with valid auth passes through."""
app = FastAPI()
app.include_router(router)
mock_user = User(id=uuid4(), email="test@example.com", password_hash="hash")
mock_auth = AuthContext(
user=mock_user,
permissions=["runs:create", "threads:read", "threads:write"],
)
# Mock the checkpointer and run_manager to avoid 503s
mock_checkpointer = MagicMock()
mock_run_manager = MagicMock()
mock_run_manager.list_by_thread = MagicMock(return_value=[])
mock_stream_bridge = MagicMock()
with patch("app.gateway.routers.thread_runs.get_checkpointer", return_value=mock_checkpointer):
with patch("app.gateway.routers.thread_runs.get_run_manager", return_value=mock_run_manager):
with patch("app.gateway.routers.thread_runs.get_stream_bridge", return_value=mock_stream_bridge):
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
with TestClient(app, raise_server_exceptions=False) as client:
# Without a real checkpointer.setup, this will 500 - but the point is auth passed
response = client.post(
"/api/threads/test-thread/runs",
json={"assistant_id": "test"},
)
# Auth passed if we don't get 401
assert response.status_code != 401
def test_list_runs_requires_auth():
"""GET /{thread_id}/runs requires auth."""
app = FastAPI()
app.include_router(router)
with TestClient(app, raise_server_exceptions=False) as client:
response = client.get("/api/threads/test-thread/runs")
assert response.status_code == 401
def test_list_runs_with_auth():
"""GET /{thread_id}/runs with auth passes through."""
app = FastAPI()
app.include_router(router)
mock_user = User(id=uuid4(), email="test@example.com", password_hash="hash")
mock_auth = AuthContext(
user=mock_user,
permissions=["runs:read", "threads:read"],
)
mock_run_manager = MagicMock()
mock_run_manager.list_by_thread = MagicMock(return_value=[])
with patch("app.gateway.routers.thread_runs.get_run_manager", return_value=mock_run_manager):
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
with TestClient(app, raise_server_exceptions=False) as client:
response = client.get("/api/threads/test-thread/runs")
# Should not be 401 (may be 500 or other, but auth passed)
assert response.status_code != 401
def test_get_run_requires_auth():
"""GET /{thread_id}/runs/{run_id} requires auth."""
app = FastAPI()
app.include_router(router)
with TestClient(app, raise_server_exceptions=False) as client:
response = client.get("/api/threads/test-thread/runs/run-123")
assert response.status_code == 401
def test_cancel_run_requires_auth():
"""POST /{thread_id}/runs/{run_id}/cancel requires auth."""
app = FastAPI()
app.include_router(router)
with TestClient(app, raise_server_exceptions=False) as client:
response = client.post("/api/threads/test-thread/runs/run-123/cancel")
assert response.status_code == 401
+33 -7
View File
@@ -49,20 +49,43 @@ def test_delete_thread_data_rejects_invalid_thread_id(tmp_path):
def test_delete_thread_route_cleans_thread_directory(tmp_path):
"""DELETE /{thread_id} requires auth + permission — mock auth and store."""
from unittest.mock import AsyncMock, MagicMock
from app.gateway.authz import AuthContext
tid = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee"
paths = Paths(tmp_path)
thread_dir = paths.thread_dir("thread-route")
paths.sandbox_work_dir("thread-route").mkdir(parents=True, exist_ok=True)
(paths.sandbox_work_dir("thread-route") / "notes.txt").write_text("hello", encoding="utf-8")
thread_dir = paths.thread_dir(tid)
paths.sandbox_work_dir(tid).mkdir(parents=True, exist_ok=True)
(paths.sandbox_work_dir(tid) / "notes.txt").write_text("hello", encoding="utf-8")
# Mock store item with .value attribute
mock_store = MagicMock()
mock_record = {
"thread_id": tid,
"metadata": {"user_id": "test-user-123"},
}
mock_store_item = MagicMock()
mock_store_item.value = mock_record
mock_store.aget = AsyncMock(return_value=mock_store_item)
mock_user = MagicMock()
mock_user.id = "test-user-123"
mock_auth = AuthContext(user=mock_user, permissions=["threads:read", "threads:write", "threads:delete"])
app = FastAPI()
app.include_router(threads.router)
with patch("app.gateway.routers.threads.get_paths", return_value=paths):
with TestClient(app) as client:
response = client.delete("/api/threads/thread-route")
with patch("app.gateway.routers.threads.get_store", return_value=mock_store):
with patch("app.gateway.routers.threads.get_checkpointer", return_value=MagicMock()):
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
with TestClient(app) as client:
response = client.delete(f"/api/threads/{tid}")
assert response.status_code == 200
assert response.json() == {"success": True, "message": "Deleted local thread data for thread-route"}
assert response.json() == {"success": True, "message": f"Deleted local thread data for {tid}"}
assert not thread_dir.exists()
@@ -80,6 +103,7 @@ def test_delete_thread_route_rejects_invalid_thread_id(tmp_path):
def test_delete_thread_route_returns_422_for_route_safe_invalid_id(tmp_path):
"""DELETE /{thread_id} with non-UUID id — FastAPI rejects at path validation."""
paths = Paths(tmp_path)
app = FastAPI()
@@ -90,7 +114,9 @@ def test_delete_thread_route_returns_422_for_route_safe_invalid_id(tmp_path):
response = client.delete("/api/threads/thread.with.dot")
assert response.status_code == 422
assert "Invalid thread_id" in response.json()["detail"]
# FastAPI returns a list of validation errors for path parameter mismatch
detail = response.json()["detail"]
assert any("thread_id" in str(err) for err in detail)
def test_delete_thread_data_returns_generic_500_error(tmp_path):
@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.agents.middlewares import title_middleware as title_middleware_module
from deerflow.agents.middlewares.title_middleware import TitleMiddleware
from deerflow.config.title_config import TitleConfig, get_title_config, set_title_config
@@ -73,37 +74,32 @@ class TestTitleMiddlewareCoreLogic:
assert middleware._should_generate_title(state) is False
def test_generate_title_trims_quotes_and_respects_max_chars(self, monkeypatch):
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
_set_test_title_config(max_chars=12)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.ainvoke = AsyncMock(return_value=MagicMock(content='"A very long generated title"'))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
HumanMessage(content="请帮我写一个脚本"),
HumanMessage(content="请帮我写一个很长很长的脚本标题"),
AIMessage(content="好的,先确认需求"),
]
}
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
assert '"' not in title
assert "'" not in title
assert len(title) == 12
assert title == "短标题"
title_middleware_module.create_chat_model.assert_called_once_with(thinking_enabled=False)
model.ainvoke.assert_awaited_once()
def test_generate_title_normalizes_structured_message_and_response_content(self, monkeypatch):
def test_generate_title_normalizes_structured_message_content(self, monkeypatch):
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.ainvoke = AsyncMock(
return_value=MagicMock(content=[{"type": "text", "text": '"结构总结"'}]),
)
monkeypatch.setattr(
"deerflow.agents.middlewares.title_middleware.create_chat_model",
lambda **kwargs: fake_model,
)
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="请帮我总结这段代码"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
@@ -115,21 +111,14 @@ class TestTitleMiddlewareCoreLogic:
result = asyncio.run(middleware._agenerate_title_result(state))
title = result["title"]
prompt = fake_model.ainvoke.await_args.args[0]
assert "请帮我总结这段代码" in prompt
assert "好的,先看结构" in prompt
# Ensure structured message dict/JSON reprs are not leaking into the prompt.
assert "{'type':" not in prompt
assert "'type':" not in prompt
assert '"type":' not in prompt
assert title == "结构总结"
assert title == "请帮我总结这段代码"
def test_generate_title_fallback_when_model_fails(self, monkeypatch):
def test_generate_title_fallback_for_long_message(self, monkeypatch):
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.ainvoke = AsyncMock(side_effect=RuntimeError("LLM unavailable"))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
model = MagicMock()
model.ainvoke = AsyncMock(side_effect=RuntimeError("model unavailable"))
monkeypatch.setattr(title_middleware_module, "create_chat_model", MagicMock(return_value=model))
state = {
"messages": [
@@ -164,13 +153,10 @@ class TestTitleMiddlewareCoreLogic:
monkeypatch.setattr(middleware, "_generate_title_result", MagicMock(return_value=None))
assert middleware.after_model({"messages": []}, runtime=MagicMock()) is None
def test_sync_generate_title_with_model(self, monkeypatch):
"""Sync path calls model.invoke and produces a title."""
def test_sync_generate_title_uses_fallback_without_model(self):
"""Sync path avoids LLM calls and derives a local fallback title."""
_set_test_title_config(max_chars=20)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.invoke = MagicMock(return_value=MagicMock(content='"同步生成的标题"'))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
state = {
"messages": [
@@ -179,22 +165,19 @@ class TestTitleMiddlewareCoreLogic:
]
}
result = middleware._generate_title_result(state)
assert result == {"title": "同步生成的标题"}
fake_model.invoke.assert_called_once()
assert result == {"title": "请帮我写测试"}
def test_empty_title_falls_back(self, monkeypatch):
"""Empty model response triggers fallback title."""
def test_sync_generate_title_respects_fallback_truncation(self):
"""Sync fallback path still respects max_chars truncation rules."""
_set_test_title_config(max_chars=50)
middleware = TitleMiddleware()
fake_model = MagicMock()
fake_model.invoke = MagicMock(return_value=MagicMock(content=" "))
monkeypatch.setattr("deerflow.agents.middlewares.title_middleware.create_chat_model", lambda **kwargs: fake_model)
state = {
"messages": [
HumanMessage(content="空标题测试"),
HumanMessage(content="这是一个非常长的问题描述,需要被截断以形成fallback标题,而且这里继续补充更多上下文,确保超过本地fallback截断阈值"),
AIMessage(content="回复"),
]
}
result = middleware._generate_title_result(state)
assert result["title"] == "空标题测试"
assert result["title"].endswith("...")
assert result["title"].startswith("这是一个非常长的问题描述")
@@ -289,6 +289,8 @@ class TestBeforeAgent:
"size": 5,
"path": "/mnt/user-data/uploads/notes.txt",
"extension": ".txt",
"outline": [],
"outline_preview": [],
}
]
@@ -339,3 +341,130 @@ class TestBeforeAgent:
result = mw.before_agent(self._state(msg), _runtime())
assert result["messages"][-1].id == "original-id-42"
def test_outline_injected_when_md_file_exists(self, tmp_path):
"""When a converted .md file exists alongside the upload, its outline is injected."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "report.pdf").write_bytes(b"%PDF fake")
# Simulate the .md produced by the conversion pipeline
(uploads_dir / "report.md").write_text(
"# PART I\n\n## ITEM 1. BUSINESS\n\nBody text.\n\n## ITEM 2. RISK\n",
encoding="utf-8",
)
msg = _human("summarise", files=[{"filename": "report.pdf", "size": 9, "path": "/mnt/user-data/uploads/report.pdf"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert "Document outline" in content
assert "PART I" in content
assert "ITEM 1. BUSINESS" in content
assert "ITEM 2. RISK" in content
assert "read_file" in content
def test_no_outline_when_no_md_file(self, tmp_path):
"""Files without a sibling .md have no outline section."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "data.xlsx").write_bytes(b"fake-xlsx")
msg = _human("analyse", files=[{"filename": "data.xlsx", "size": 9, "path": "/mnt/user-data/uploads/data.xlsx"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert "Document outline" not in content
def test_outline_truncation_hint_shown(self, tmp_path):
"""When outline is truncated, a hint line is appended after the last visible entry."""
from deerflow.utils.file_conversion import MAX_OUTLINE_ENTRIES
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "big.pdf").write_bytes(b"%PDF fake")
# Write MAX_OUTLINE_ENTRIES + 5 headings so truncation is triggered
headings = "\n".join(f"# Heading {i}" for i in range(MAX_OUTLINE_ENTRIES + 5))
(uploads_dir / "big.md").write_text(headings, encoding="utf-8")
msg = _human("read", files=[{"filename": "big.pdf", "size": 9, "path": "/mnt/user-data/uploads/big.pdf"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert f"showing first {MAX_OUTLINE_ENTRIES} headings" in content
assert "use `read_file` to explore further" in content
def test_no_truncation_hint_for_short_outline(self, tmp_path):
"""Short outlines (under the cap) must not show a truncation hint."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "short.pdf").write_bytes(b"%PDF fake")
(uploads_dir / "short.md").write_text("# Intro\n\n# Conclusion\n", encoding="utf-8")
msg = _human("read", files=[{"filename": "short.pdf", "size": 9, "path": "/mnt/user-data/uploads/short.pdf"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert "showing first" not in content
def test_historical_file_outline_injected(self, tmp_path):
"""Outline is also shown for historical (previously uploaded) files."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
# Historical file with .md
(uploads_dir / "old_report.pdf").write_bytes(b"%PDF old")
(uploads_dir / "old_report.md").write_text(
"# Chapter 1\n\n# Chapter 2\n",
encoding="utf-8",
)
# New file without .md
(uploads_dir / "new.txt").write_bytes(b"new")
msg = _human("go", files=[{"filename": "new.txt", "size": 3, "path": "/mnt/user-data/uploads/new.txt"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert "Chapter 1" in content
assert "Chapter 2" in content
def test_fallback_preview_shown_when_outline_empty(self, tmp_path):
"""When .md exists but has no headings, first lines are shown as a preview."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "report.pdf").write_bytes(b"%PDF fake")
# .md with no # headings — plain prose only
(uploads_dir / "report.md").write_text(
"Annual Financial Report 2024\n\nThis document summarises key findings.\n\nRevenue grew by 12%.\n",
encoding="utf-8",
)
msg = _human("analyse", files=[{"filename": "report.pdf", "size": 9, "path": "/mnt/user-data/uploads/report.pdf"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
# Outline section must NOT appear
assert "Document outline" not in content
# Preview lines must appear
assert "Annual Financial Report 2024" in content
assert "No structural headings detected" in content
# grep hint must appear
assert "grep" in content
def test_fallback_grep_hint_shown_when_no_md_file(self, tmp_path):
"""Files with no sibling .md still get the grep hint (outline is empty)."""
mw = _middleware(tmp_path)
uploads_dir = _uploads_dir(tmp_path)
(uploads_dir / "data.csv").write_bytes(b"a,b,c\n1,2,3\n")
msg = _human("analyse", files=[{"filename": "data.csv", "size": 12, "path": "/mnt/user-data/uploads/data.csv"}])
result = mw.before_agent(self._state(msg), _runtime())
assert result is not None
content = result["messages"][-1].content
assert "Document outline" not in content
assert "grep" in content