mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
[security] fix(auth): reject cross-site auth POSTs (#2740)
* fix(security): reject cross-site auth posts * fix(auth): align secure cookie proxy scheme handling --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -4,8 +4,10 @@ Per RFC-001:
|
|||||||
State-changing operations require CSRF protection.
|
State-changing operations require CSRF protection.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from urllib.parse import urlsplit
|
||||||
|
|
||||||
from fastapi import Request, Response
|
from fastapi import Request, Response
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
@@ -19,7 +21,7 @@ CSRF_TOKEN_LENGTH = 64 # bytes
|
|||||||
|
|
||||||
def is_secure_request(request: Request) -> bool:
|
def is_secure_request(request: Request) -> bool:
|
||||||
"""Detect whether the original client request was made over HTTPS."""
|
"""Detect whether the original client request was made over HTTPS."""
|
||||||
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
|
return _request_scheme(request) == "https"
|
||||||
|
|
||||||
|
|
||||||
def generate_csrf_token() -> str:
|
def generate_csrf_token() -> str:
|
||||||
@@ -61,6 +63,109 @@ def is_auth_endpoint(request: Request) -> bool:
|
|||||||
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||||
|
|
||||||
|
|
||||||
|
def _host_with_optional_port(hostname: str, port: int | None, scheme: str) -> str:
|
||||||
|
"""Return normalized host[:port], omitting default ports."""
|
||||||
|
host = hostname.lower()
|
||||||
|
if ":" in host and not host.startswith("["):
|
||||||
|
host = f"[{host}]"
|
||||||
|
|
||||||
|
if port is None or (scheme == "http" and port == 80) or (scheme == "https" and port == 443):
|
||||||
|
return host
|
||||||
|
return f"{host}:{port}"
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_origin(origin: str) -> str | None:
|
||||||
|
"""Return a normalized scheme://host[:port] origin, or None for invalid input."""
|
||||||
|
try:
|
||||||
|
parsed = urlsplit(origin.strip())
|
||||||
|
port = parsed.port
|
||||||
|
except ValueError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
scheme = parsed.scheme.lower()
|
||||||
|
if scheme not in {"http", "https"} or not parsed.hostname:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Browser Origin is only scheme/host/port. Reject URL-shaped or credentialed values.
|
||||||
|
if parsed.username or parsed.password or parsed.path or parsed.query or parsed.fragment:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return f"{scheme}://{_host_with_optional_port(parsed.hostname, port, scheme)}"
|
||||||
|
|
||||||
|
|
||||||
|
def _configured_cors_origins() -> set[str]:
|
||||||
|
"""Return explicit configured browser origins that may call auth routes."""
|
||||||
|
origins = set()
|
||||||
|
for raw_origin in os.environ.get("GATEWAY_CORS_ORIGINS", "").split(","):
|
||||||
|
origin = raw_origin.strip()
|
||||||
|
if not origin or origin == "*":
|
||||||
|
continue
|
||||||
|
normalized = _normalize_origin(origin)
|
||||||
|
if normalized:
|
||||||
|
origins.add(normalized)
|
||||||
|
return origins
|
||||||
|
|
||||||
|
|
||||||
|
def _first_header_value(value: str | None) -> str | None:
|
||||||
|
"""Return the first value from a comma-separated proxy header."""
|
||||||
|
if not value:
|
||||||
|
return None
|
||||||
|
first = value.split(",", 1)[0].strip()
|
||||||
|
return first or None
|
||||||
|
|
||||||
|
|
||||||
|
def _forwarded_param(request: Request, name: str) -> str | None:
|
||||||
|
"""Extract a parameter from the first RFC 7239 Forwarded header entry."""
|
||||||
|
forwarded = _first_header_value(request.headers.get("forwarded"))
|
||||||
|
if not forwarded:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for part in forwarded.split(";"):
|
||||||
|
key, sep, value = part.strip().partition("=")
|
||||||
|
if sep and key.lower() == name:
|
||||||
|
return value.strip().strip('"') or None
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _request_scheme(request: Request) -> str:
|
||||||
|
"""Resolve the original request scheme from trusted proxy headers."""
|
||||||
|
scheme = _forwarded_param(request, "proto") or _first_header_value(request.headers.get("x-forwarded-proto")) or request.url.scheme
|
||||||
|
return scheme.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _request_origin(request: Request) -> str | None:
|
||||||
|
"""Build the origin for the URL the browser is targeting."""
|
||||||
|
scheme = _request_scheme(request)
|
||||||
|
host = _forwarded_param(request, "host") or _first_header_value(request.headers.get("x-forwarded-host")) or request.headers.get("host") or request.url.netloc
|
||||||
|
|
||||||
|
forwarded_port = _first_header_value(request.headers.get("x-forwarded-port"))
|
||||||
|
if forwarded_port and ":" not in host.rsplit("]", 1)[-1]:
|
||||||
|
host = f"{host}:{forwarded_port}"
|
||||||
|
|
||||||
|
return _normalize_origin(f"{scheme}://{host}")
|
||||||
|
|
||||||
|
|
||||||
|
def is_allowed_auth_origin(request: Request) -> bool:
|
||||||
|
"""Allow auth POSTs only from the same origin or explicit configured origins.
|
||||||
|
|
||||||
|
Login/register/initialize are exempt from the double-submit token because
|
||||||
|
first-time browser clients do not have a CSRF token yet. They still create
|
||||||
|
a session cookie, so browser requests with a hostile Origin header must be
|
||||||
|
rejected to prevent login CSRF / session fixation. Requests without Origin
|
||||||
|
are allowed for non-browser clients such as curl and mobile integrations.
|
||||||
|
"""
|
||||||
|
origin = request.headers.get("origin")
|
||||||
|
if not origin:
|
||||||
|
return True
|
||||||
|
|
||||||
|
normalized_origin = _normalize_origin(origin)
|
||||||
|
if normalized_origin is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
request_origin = _request_origin(request)
|
||||||
|
return normalized_origin in _configured_cors_origins() or (request_origin is not None and normalized_origin == request_origin)
|
||||||
|
|
||||||
|
|
||||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||||
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
||||||
|
|
||||||
@@ -70,6 +175,12 @@ class CSRFMiddleware(BaseHTTPMiddleware):
|
|||||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
_is_auth = is_auth_endpoint(request)
|
_is_auth = is_auth_endpoint(request)
|
||||||
|
|
||||||
|
if should_check_csrf(request) and _is_auth and not is_allowed_auth_origin(request):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={"detail": "Cross-site auth request denied."},
|
||||||
|
)
|
||||||
|
|
||||||
if should_check_csrf(request) and not _is_auth:
|
if should_check_csrf(request) and not _is_auth:
|
||||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||||
header_token = request.headers.get(CSRF_HEADER_NAME)
|
header_token = request.headers.get(CSRF_HEADER_NAME)
|
||||||
|
|||||||
@@ -0,0 +1,219 @@
|
|||||||
|
"""Tests for CSRF middleware."""
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app() -> FastAPI:
|
||||||
|
app = FastAPI()
|
||||||
|
app.add_middleware(CSRFMiddleware)
|
||||||
|
|
||||||
|
@app.post("/api/v1/auth/login/local")
|
||||||
|
async def login_local():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@app.post("/api/v1/auth/register")
|
||||||
|
async def register():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@app.post("/api/threads/abc/runs/stream")
|
||||||
|
async def protected_mutation():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_rejects_cross_origin_browser_request():
|
||||||
|
"""CSRF-exempt auth routes must not accept hostile browser origins.
|
||||||
|
|
||||||
|
Login/register endpoints intentionally skip the double-submit token because
|
||||||
|
first-time callers do not have a token yet. They still set an auth session,
|
||||||
|
so a hostile cross-site form POST must be rejected to avoid login CSRF /
|
||||||
|
session fixation.
|
||||||
|
"""
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
headers={"Origin": "https://evil.example"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_allows_same_origin_browser_request():
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
headers={"Origin": "https://deerflow.example"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.cookies.get("csrf_token")
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_rejects_malformed_origin_with_path():
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
headers={"Origin": "https://deerflow.example/path"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||||
|
assert response.cookies.get("csrf_token") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_rejects_malformed_origin_with_invalid_port():
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
headers={"Origin": "https://deerflow.example:bad"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||||
|
assert response.cookies.get("csrf_token") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_allows_same_origin_default_port_equivalence():
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
headers={"Origin": "https://deerflow.example:443"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.cookies.get("csrf_token")
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_allows_forwarded_same_origin():
|
||||||
|
client = TestClient(_make_app(), base_url="http://internal:8000")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
headers={
|
||||||
|
"Origin": "https://deerflow.example",
|
||||||
|
"X-Forwarded-Proto": "https",
|
||||||
|
"X-Forwarded-Host": "deerflow.example, internal:8000",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.cookies.get("csrf_token")
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_allows_rfc_forwarded_same_origin():
|
||||||
|
client = TestClient(_make_app(), base_url="http://internal:8000")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
headers={
|
||||||
|
"Origin": "https://deerflow.example",
|
||||||
|
"Forwarded": "proto=https;host=deerflow.example",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.cookies.get("csrf_token")
|
||||||
|
assert "secure" in response.headers["set-cookie"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_allows_explicit_configured_origin(monkeypatch):
|
||||||
|
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "https://app.example")
|
||||||
|
client = TestClient(_make_app(), base_url="https://api.example")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
headers={"Origin": "https://app.example"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.cookies.get("csrf_token")
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_does_not_treat_wildcard_cors_as_allowed_origin(monkeypatch):
|
||||||
|
monkeypatch.setenv("GATEWAY_CORS_ORIGINS", "*")
|
||||||
|
client = TestClient(_make_app(), base_url="https://api.example")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
headers={"Origin": "https://evil.example"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json()["detail"] == "Cross-site auth request denied."
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_sets_strict_samesite_csrf_cookie():
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
headers={"Origin": "https://deerflow.example"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
set_cookie = response.headers["set-cookie"].lower()
|
||||||
|
assert "csrf_token=" in set_cookie
|
||||||
|
assert "samesite=strict" in set_cookie
|
||||||
|
assert "secure" in set_cookie
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_post_without_origin_still_allows_non_browser_clients():
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
|
||||||
|
response = client.post("/api/v1/auth/login/local")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.cookies.get("csrf_token")
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_auth_mutation_still_requires_double_submit_token():
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/threads/abc/runs/stream",
|
||||||
|
headers={"Origin": "https://deerflow.example"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_auth_mutation_allows_valid_double_submit_token():
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
client.cookies.set("csrf_token", "known-token")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/threads/abc/runs/stream",
|
||||||
|
headers={
|
||||||
|
"Origin": "https://deerflow.example",
|
||||||
|
"X-CSRF-Token": "known-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_non_auth_mutation_rejects_mismatched_double_submit_token():
|
||||||
|
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||||
|
client.cookies.set("csrf_token", "cookie-token")
|
||||||
|
|
||||||
|
response = client.post(
|
||||||
|
"/api/threads/abc/runs/stream",
|
||||||
|
headers={
|
||||||
|
"Origin": "https://deerflow.example",
|
||||||
|
"X-CSRF-Token": "header-token",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert response.json()["detail"] == "CSRF token mismatch."
|
||||||
Reference in New Issue
Block a user