"""CSRF protection middleware for FastAPI. Per RFC-001: State-changing operations require CSRF protection. """ import secrets from collections.abc import Callable from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from starlette.responses import JSONResponse from starlette.types import ASGIApp CSRF_COOKIE_NAME = "csrf_token" CSRF_HEADER_NAME = "X-CSRF-Token" CSRF_TOKEN_LENGTH = 64 # bytes def is_secure_request(request: Request) -> bool: """Detect whether the original client request was made over HTTPS.""" return request.headers.get("x-forwarded-proto", request.url.scheme) == "https" def generate_csrf_token() -> str: """Generate a secure random CSRF token.""" return secrets.token_urlsafe(CSRF_TOKEN_LENGTH) def should_check_csrf(request: Request) -> bool: """Determine if a request needs CSRF validation. CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH). GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231. """ if request.method not in ("POST", "PUT", "DELETE", "PATCH"): return False path = request.url.path.rstrip("/") # Exempt /api/v1/auth/me endpoint if path == "/api/v1/auth/me": return False return True _AUTH_EXEMPT_PATHS: frozenset[str] = frozenset( { "/api/v1/auth/login/local", "/api/v1/auth/logout", "/api/v1/auth/register", "/api/v1/auth/initialize", } ) def is_auth_endpoint(request: Request) -> bool: """Check if the request is to an auth endpoint. Auth endpoints don't need CSRF validation on first call (no token). """ return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS class CSRFMiddleware(BaseHTTPMiddleware): """Middleware that implements CSRF protection using Double Submit Cookie pattern.""" def __init__(self, app: ASGIApp) -> None: super().__init__(app) async def dispatch(self, request: Request, call_next: Callable) -> Response: _is_auth = is_auth_endpoint(request) if should_check_csrf(request) and not _is_auth: cookie_token = request.cookies.get(CSRF_COOKIE_NAME) header_token = request.headers.get(CSRF_HEADER_NAME) if not cookie_token or not header_token: return JSONResponse( status_code=403, content={"detail": "CSRF token missing. Include X-CSRF-Token header."}, ) if not secrets.compare_digest(cookie_token, header_token): return JSONResponse( status_code=403, content={"detail": "CSRF token mismatch."}, ) response = await call_next(request) # For auth endpoints that set up session, also set CSRF cookie if _is_auth and request.method == "POST": # Generate a new CSRF token for the session csrf_token = generate_csrf_token() is_https = is_secure_request(request) response.set_cookie( key=CSRF_COOKIE_NAME, value=csrf_token, httponly=False, # Must be JS-readable for Double Submit Cookie pattern secure=is_https, samesite="strict", ) return response def get_csrf_token(request: Request) -> str | None: """Get the CSRF token from the current request's cookies. This is useful for server-side rendering where you need to embed token in forms or headers. """ return request.cookies.get(CSRF_COOKIE_NAME)