[security] fix(auth): reject cross-site auth POSTs (#2740)

* fix(security): reject cross-site auth posts

* fix(auth): align secure cookie proxy scheme handling

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
Hinotobi
2026-05-07 07:58:06 +08:00
committed by GitHub
parent 1336872b15
commit 2b0e62f679
2 changed files with 331 additions and 1 deletions
+112 -1
View File
@@ -4,8 +4,10 @@ Per RFC-001:
State-changing operations require CSRF protection. State-changing operations require CSRF protection.
""" """
import os
import secrets import secrets
from collections.abc import Callable from collections.abc import Callable
from urllib.parse import urlsplit
from fastapi import Request, Response from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
@@ -19,7 +21,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes
def is_secure_request(request: Request) -> bool: def is_secure_request(request: Request) -> bool:
"""Detect whether the original client request was made over HTTPS.""" """Detect whether the original client request was made over HTTPS."""
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https" return _request_scheme(request) == "https"
def generate_csrf_token() -> str: def generate_csrf_token() -> str:
@@ -61,6 +63,109 @@ def is_auth_endpoint(request: Request) -> bool:
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
"""Return normalized host[:port], omitting default ports."""
host = hostname.lower()
if ":" in host and not host.startswith("["):
host = f"[{host}]"
if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
return host
return f"{host}:{port}"
def _normalize_origin(origin: str) -> str | None:
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
try:
parsed = urlsplit(origin.strip())
port = parsed.port
except ValueError:
return None
scheme = parsed.scheme.lower()
if scheme not in {"http", "https"} or not parsed.hostname:
return None
# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
return None
return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"
def _configured_cors_origins() -> set[str]:
"""Return explicit configured browser origins that may call auth routes."""
origins = set()
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
origin = raw_origin.strip()
if not origin or origin == "*":
continue
normalized = _normalize_origin(origin)
if normalized:
origins.add(normalized)
return origins
def _first_header_value(value: str | None) -> str | None:
"""Return the first value from a comma-separated proxy header."""
if not value:
return None
first = value.split(",", 1)[0].strip()
return first or None
def _forwarded_param(request: Request, name: str) -> str | None:
"""Extract a parameter from the first RFC 7239 Forwarded header entry."""
forwarded = _first_header_value(request.headers.get("forwarded"))
if not forwarded:
return None
for part in forwarded.split(";"):
key, sep, value = part.strip().partition("=")
if sep and key.lower() == name:
return value.strip().strip('"') or None
return None
def _request_scheme(request: Request) -> str:
"""Resolve the original request scheme from trusted proxy headers."""
scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme
return scheme.lower()
def _request_origin(request: Request) -> str | None:
"""Build the origin for the URL the browser is targeting."""
scheme = _request_scheme(request)
host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc
forwarded_port = _first_header_value(request.headers.get("x-forwarded-port"))
if forwarded_port and ":" not in host.rsplit("]", 1)[-1]:
host = f"{host}:{forwarded_port}"
return _normalize_origin(f"{scheme}://{host}")
def is_allowed_auth_origin(request: Request) -> bool:
"""Allow auth POSTs only from the same origin or explicit configured origins.
Login/register/initialize are exempt from the double-submit token because
first-time browser clients do not have a CSRF token yet. They still create
a session cookie, so browser requests with a hostile Origin header must be
rejected to prevent login CSRF / session fixation. Requests without Origin
are allowed for non-browser clients such as curl and mobile integrations.
"""
origin = request.headers.get("origin")
if not origin:
return True
normalized_origin = _normalize_origin(origin)
if normalized_origin is None:
return False
request_origin = _request_origin(request)
return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin)
class CSRFMiddleware(BaseHTTPMiddleware): class CSRFMiddleware(BaseHTTPMiddleware):
"""Middleware that implements CSRF protection using Double Submit Cookie pattern.""" """Middleware that implements CSRF protection using Double Submit Cookie pattern."""
@@ -70,6 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next: Callable) -> Response: async def dispatch(self, request: Request, call_next: Callable) -> Response:
_is_auth = is_auth_endpoint(request) _is_auth = is_auth_endpoint(request)
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
return JSONResponse(
status_code=403,
content={"detail": "Cross-site auth request denied."},
)
if should_check_csrf(request) and not _is_auth: if should_check_csrf(request) and not _is_auth:
cookie_token = request.cookies.get(CSRF_COOKIE_NAME) cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
header_token = request.headers.get(CSRF_HEADER_NAME) header_token = request.headers.get(CSRF_HEADER_NAME)
+219
View File
@@ -0,0 +1,219 @@
"""Tests for CSRF middleware."""
from fastapi import FastAPI
from starlette.testclient import TestClient
from app.gateway.csrf_middleware import CSRFMiddleware
def _make_app() -> FastAPI:
app = FastAPI()
app.add_middleware(CSRFMiddleware)
@app.post("/api/v1/auth/login/local")
async def login_local():
return {"ok": True}
@app.post("/api/v1/auth/register")
async def register():
return {"ok": True}
@app.post("/api/threads/abc/runs/stream")
async def protected_mutation():
return {"ok": True}
return app
def test_auth_post_rejects_cross_origin_browser_request():
"""CSRF-exempt auth routes must not accept hostile browser origins.
Login/register endpoints intentionally skip the double-submit token because
first-time callers do not have a token yet. They still set an auth session,
so a hostile cross-site form POST must be rejected to avoid login CSRF /
session fixation.
"""
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
def test_auth_post_allows_same_origin_browser_request():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_rejects_malformed_origin_with_path():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example/path"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None
def test_auth_post_rejects_malformed_origin_with_invalid_port():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:bad"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
assert response.cookies.get("csrf_token") is None
def test_auth_post_allows_same_origin_default_port_equivalence():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example:443"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"X-Forwarded-Proto": "https",
"X-Forwarded-Host": "deerflow.example, internal:8000",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_allows_rfc_forwarded_same_origin():
client = TestClient(_make_app(), base_url="http://internal:8000")
response = client.post(
"/api/v1/auth/login/local",
headers={
"Origin": "https://deerflow.example",
"Forwarded": "proto=https;host=deerflow.example",
},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
assert "secure" in response.headers["set-cookie"].lower()
def test_auth_post_allows_explicit_configured_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example")
client = TestClient(_make_app(), base_url="https://api.example")
response = client.post(
"/api/v1/auth/register",
headers={"Origin": "https://app.example"},
)
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch):
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*")
client = TestClient(_make_app(), base_url="https://api.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://evil.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "Cross-site auth request denied."
def test_auth_post_sets_strict_samesite_csrf_cookie():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/v1/auth/login/local",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 200
set_cookie = response.headers["set-cookie"].lower()
assert "csrf_token=" in set_cookie
assert "samesite=strict" in set_cookie
assert "secure" in set_cookie
def test_auth_post_without_origin_still_allows_non_browser_clients():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post("/api/v1/auth/login/local")
assert response.status_code == 200
assert response.cookies.get("csrf_token")
def test_non_auth_mutation_still_requires_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
response = client.post(
"/api/threads/abc/runs/stream",
headers={"Origin": "https://deerflow.example"},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
def test_non_auth_mutation_allows_valid_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "known-token")
response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "known-token",
},
)
assert response.status_code == 200
def test_non_auth_mutation_rejects_mismatched_double_submit_token():
client = TestClient(_make_app(), base_url="https://deerflow.example")
client.cookies.set("csrf_token", "cookie-token")
response = client.post(
"/api/threads/abc/runs/stream",
headers={
"Origin": "https://deerflow.example",
"X-CSRF-Token": "header-token",
},
)
assert response.status_code == 403
assert response.json()["detail"] == "CSRF token mismatch."