feat(auth): wire auth end-to-end (middleware + frontend replacement)
Backend: - Port auth_middleware, csrf_middleware, langgraph_auth, routers/auth - Port authz decorator (owner_filter_key defaults to 'owner_id') - Merge app.py: register AuthMiddleware + CSRFMiddleware + CORS, add _ensure_admin_user lifespan hook, _migrate_orphaned_threads helper, register auth router - Merge deps.py: add get_local_provider, get_current_user_from_request, get_optional_user_from_request; keep get_current_user as thin str|None adapter for feedback router - langgraph.json: add auth path pointing to langgraph_auth.py:auth - Rename metadata['user_id'] -> metadata['owner_id'] in langgraph_auth (both metadata write and LangGraph filter dict) + test fixtures Frontend: - Delete better-auth library and api catch-all route - Remove better-auth npm dependency and env vars (BETTER_AUTH_SECRET, BETTER_AUTH_GITHUB_*) from env.js - Port frontend/src/core/auth/* (AuthProvider, gateway-config, proxy-policy, server-side getServerSideUser, types) - Port frontend/src/core/api/fetcher.ts - Port (auth)/layout, (auth)/login, (auth)/setup pages - Rewrite workspace/layout.tsx as server component that calls getServerSideUser and wraps in AuthProvider - Port workspace/workspace-content.tsx for the client-side sidebar logic Tests: - Port 5 auth test files (test_auth, test_auth_middleware, test_auth_type_system, test_ensure_admin, test_langgraph_auth) - 176 auth tests PASS After this commit: login/logout/registration flow works, but persistence layer does not yet filter by owner_id. Commit 4 closes that gap.
This commit is contained in:
+128
-1
@@ -1,15 +1,21 @@
|
|||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
from datetime import UTC
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.gateway.auth_middleware import AuthMiddleware
|
||||||
from app.gateway.config import get_gateway_config
|
from app.gateway.config import get_gateway_config
|
||||||
|
from app.gateway.csrf_middleware import CSRFMiddleware
|
||||||
from app.gateway.deps import langgraph_runtime
|
from app.gateway.deps import langgraph_runtime
|
||||||
from app.gateway.routers import (
|
from app.gateway.routers import (
|
||||||
agents,
|
agents,
|
||||||
artifacts,
|
artifacts,
|
||||||
assistants_compat,
|
assistants_compat,
|
||||||
|
auth,
|
||||||
channels,
|
channels,
|
||||||
feedback,
|
feedback,
|
||||||
mcp,
|
mcp,
|
||||||
@@ -34,6 +40,92 @@ logging.basicConfig(
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||||
|
"""Auto-create the admin user on first boot if no users exist.
|
||||||
|
|
||||||
|
Prints the generated password to stdout so the operator can log in.
|
||||||
|
On subsequent boots, warns if any user still needs setup.
|
||||||
|
|
||||||
|
Multi-worker safe: relies on SQLite UNIQUE constraint to resolve races.
|
||||||
|
Only the worker that successfully creates/updates the admin prints the
|
||||||
|
password; losers silently skip.
|
||||||
|
"""
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from app.gateway.deps import get_local_provider
|
||||||
|
|
||||||
|
provider = get_local_provider()
|
||||||
|
user_count = await provider.count_users()
|
||||||
|
|
||||||
|
if user_count == 0:
|
||||||
|
password = secrets.token_urlsafe(16)
|
||||||
|
try:
|
||||||
|
admin = await provider.create_user(email="admin@deerflow.dev", password=password, system_role="admin", needs_setup=True)
|
||||||
|
except ValueError:
|
||||||
|
return # Another worker already created the admin.
|
||||||
|
|
||||||
|
# Migrate orphaned threads (no owner_id) to this admin
|
||||||
|
store = getattr(app.state, "store", None)
|
||||||
|
if store is not None:
|
||||||
|
await _migrate_orphaned_threads(store, str(admin.id))
|
||||||
|
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info(" Admin account created on first boot")
|
||||||
|
logger.info(" Email: %s", admin.email)
|
||||||
|
logger.info(" Password: %s", password)
|
||||||
|
logger.info(" Change it after login: Settings -> Account")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Admin exists but setup never completed — reset password so operator
|
||||||
|
# can always find it in the console without needing the CLI.
|
||||||
|
# Multi-worker guard: if admin was created less than 30s ago, another
|
||||||
|
# worker just created it and will print the password — skip reset.
|
||||||
|
admin = await provider.get_user_by_email("admin@deerflow.dev")
|
||||||
|
if admin and admin.needs_setup:
|
||||||
|
import time
|
||||||
|
|
||||||
|
age = time.time() - admin.created_at.replace(tzinfo=UTC).timestamp()
|
||||||
|
if age < 30:
|
||||||
|
return # Just created by another worker in this startup; its password is still valid.
|
||||||
|
|
||||||
|
from app.gateway.auth.password import hash_password_async
|
||||||
|
|
||||||
|
password = secrets.token_urlsafe(16)
|
||||||
|
admin.password_hash = await hash_password_async(password)
|
||||||
|
admin.token_version += 1
|
||||||
|
await provider.update_user(admin)
|
||||||
|
|
||||||
|
logger.info("=" * 60)
|
||||||
|
logger.info(" Admin account setup incomplete — password reset")
|
||||||
|
logger.info(" Email: %s", admin.email)
|
||||||
|
logger.info(" Password: %s", password)
|
||||||
|
logger.info(" Change it after login: Settings -> Account")
|
||||||
|
logger.info("=" * 60)
|
||||||
|
|
||||||
|
|
||||||
|
async def _migrate_orphaned_threads(store, admin_user_id: str) -> None:
|
||||||
|
"""Migrate threads with no owner_id to the given admin.
|
||||||
|
|
||||||
|
NOTE: This is the initial port. Commit 5 will replace the hardcoded
|
||||||
|
limit=1000 with cursor pagination and extend to SQL persistence tables.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
migrated = 0
|
||||||
|
results = await store.asearch(("threads",), limit=1000)
|
||||||
|
for item in results:
|
||||||
|
metadata = item.value.get("metadata", {})
|
||||||
|
if not metadata.get("owner_id"):
|
||||||
|
metadata["owner_id"] = admin_user_id
|
||||||
|
item.value["metadata"] = metadata
|
||||||
|
await store.aput(("threads",), item.key, item.value)
|
||||||
|
migrated += 1
|
||||||
|
if migrated:
|
||||||
|
logger.info("Migrated %d orphaned thread(s) to admin", migrated)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("Thread migration failed (non-fatal)")
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
"""Application lifespan handler."""
|
"""Application lifespan handler."""
|
||||||
@@ -53,6 +145,10 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
async with langgraph_runtime(app):
|
async with langgraph_runtime(app):
|
||||||
logger.info("LangGraph runtime initialised")
|
logger.info("LangGraph runtime initialised")
|
||||||
|
|
||||||
|
# Ensure admin user exists (auto-create on first boot)
|
||||||
|
# Must run AFTER langgraph_runtime so app.state.store is available for thread migration
|
||||||
|
await _ensure_admin_user(app)
|
||||||
|
|
||||||
# Start IM channel service if any channels are configured
|
# Start IM channel service if any channels are configured
|
||||||
try:
|
try:
|
||||||
from app.channels.service import start_channel_service
|
from app.channels.service import start_channel_service
|
||||||
@@ -164,7 +260,35 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# CORS is handled by nginx - no need for FastAPI middleware
|
# Auth: reject unauthenticated requests to non-public paths (fail-closed safety net)
|
||||||
|
app.add_middleware(AuthMiddleware)
|
||||||
|
|
||||||
|
# CSRF: Double Submit Cookie pattern for state-changing requests
|
||||||
|
app.add_middleware(CSRFMiddleware)
|
||||||
|
|
||||||
|
# CORS: when GATEWAY_CORS_ORIGINS is set (dev without nginx), add CORS middleware.
|
||||||
|
# In production, nginx handles CORS and no middleware is needed.
|
||||||
|
cors_origins_env = os.environ.get("GATEWAY_CORS_ORIGINS", "")
|
||||||
|
if cors_origins_env:
|
||||||
|
cors_origins = [o.strip() for o in cors_origins_env.split(",") if o.strip()]
|
||||||
|
# Validate: wildcard origin with credentials is a security misconfiguration
|
||||||
|
for origin in cors_origins:
|
||||||
|
if origin == "*":
|
||||||
|
logger.error(
|
||||||
|
"GATEWAY_CORS_ORIGINS contains wildcard '*' with allow_credentials=True. "
|
||||||
|
"This is a security misconfiguration — browsers will reject the response. "
|
||||||
|
"Use explicit scheme://host:port origins instead."
|
||||||
|
)
|
||||||
|
cors_origins = [o for o in cors_origins if o != "*"]
|
||||||
|
break
|
||||||
|
if cors_origins:
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=cors_origins,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
# Include routers
|
# Include routers
|
||||||
# Models API is mounted at /api/models
|
# Models API is mounted at /api/models
|
||||||
@@ -200,6 +324,9 @@ This gateway provides custom endpoints for models, MCP configuration, skills, an
|
|||||||
# Assistants compatibility API (LangGraph Platform stub)
|
# Assistants compatibility API (LangGraph Platform stub)
|
||||||
app.include_router(assistants_compat.router)
|
app.include_router(assistants_compat.router)
|
||||||
|
|
||||||
|
# Auth API is mounted at /api/v1/auth
|
||||||
|
app.include_router(auth.router)
|
||||||
|
|
||||||
# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
|
# Feedback API is mounted at /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||||
app.include_router(feedback.router)
|
app.include_router(feedback.router)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,71 @@
|
|||||||
|
"""Global authentication middleware — fail-closed safety net.
|
||||||
|
|
||||||
|
Rejects unauthenticated requests to non-public paths with 401.
|
||||||
|
Fine-grained permission checks remain in authz.py decorators.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
from app.gateway.auth.errors import AuthErrorCode
|
||||||
|
|
||||||
|
# Paths that never require authentication.
|
||||||
|
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
|
||||||
|
"/health",
|
||||||
|
"/docs",
|
||||||
|
"/redoc",
|
||||||
|
"/openapi.json",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exact auth paths that are public (login/register/status check).
|
||||||
|
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
|
||||||
|
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
"/api/v1/auth/logout",
|
||||||
|
"/api/v1/auth/setup-status",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_public(path: str) -> bool:
|
||||||
|
stripped = path.rstrip("/")
|
||||||
|
if stripped in _PUBLIC_EXACT_PATHS:
|
||||||
|
return True
|
||||||
|
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Coarse-grained auth gate: reject requests without a valid session cookie.
|
||||||
|
|
||||||
|
This does NOT verify JWT signature or user existence — that is the job of
|
||||||
|
``get_current_user_from_request`` in deps.py (called by ``@require_auth``).
|
||||||
|
The middleware only checks *presence* of the cookie so that new endpoints
|
||||||
|
that forget ``@require_auth`` are not completely exposed.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
if _is_public(request.url.path):
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Non-public path: require session cookie
|
||||||
|
if not request.cookies.get("access_token"):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={
|
||||||
|
"detail": {
|
||||||
|
"code": AuthErrorCode.NOT_AUTHENTICATED,
|
||||||
|
"message": "Authentication required",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return await call_next(request)
|
||||||
@@ -0,0 +1,261 @@
|
|||||||
|
"""Authorization decorators and context for DeerFlow.
|
||||||
|
|
||||||
|
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
|
||||||
|
|
||||||
|
**Usage:**
|
||||||
|
|
||||||
|
1. Use ``@require_auth`` on routes that need authentication
|
||||||
|
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
|
||||||
|
3. The decorator chain processes from bottom to top
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
|
||||||
|
@router.get("/{thread_id}")
|
||||||
|
@require_auth
|
||||||
|
@require_permission("threads", "read", owner_check=True)
|
||||||
|
async def get_thread(thread_id: str, request: Request):
|
||||||
|
# User is authenticated and has threads:read permission
|
||||||
|
...
|
||||||
|
|
||||||
|
**Permission Model:**
|
||||||
|
|
||||||
|
- threads:read - View thread
|
||||||
|
- threads:write - Create/update thread
|
||||||
|
- threads:delete - Delete thread
|
||||||
|
- runs:create - Run agent
|
||||||
|
- runs:read - View run
|
||||||
|
- runs:cancel - Cancel run
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import functools
|
||||||
|
from collections.abc import Callable
|
||||||
|
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
|
from fastapi import HTTPException, Request
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.gateway.auth.models import User
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
# Permission constants
|
||||||
|
class Permissions:
|
||||||
|
"""Permission constants for resource:action format."""
|
||||||
|
|
||||||
|
# Threads
|
||||||
|
THREADS_READ = "threads:read"
|
||||||
|
THREADS_WRITE = "threads:write"
|
||||||
|
THREADS_DELETE = "threads:delete"
|
||||||
|
|
||||||
|
# Runs
|
||||||
|
RUNS_CREATE = "runs:create"
|
||||||
|
RUNS_READ = "runs:read"
|
||||||
|
RUNS_CANCEL = "runs:cancel"
|
||||||
|
|
||||||
|
|
||||||
|
class AuthContext:
|
||||||
|
"""Authentication context for the current request.
|
||||||
|
|
||||||
|
Stored in request.state.auth after require_auth decoration.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
user: The authenticated user, or None if anonymous
|
||||||
|
permissions: List of permission strings (e.g., "threads:read")
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("user", "permissions")
|
||||||
|
|
||||||
|
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
|
||||||
|
self.user = user
|
||||||
|
self.permissions = permissions or []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_authenticated(self) -> bool:
|
||||||
|
"""Check if user is authenticated."""
|
||||||
|
return self.user is not None
|
||||||
|
|
||||||
|
def has_permission(self, resource: str, action: str) -> bool:
|
||||||
|
"""Check if context has permission for resource:action.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resource: Resource name (e.g., "threads")
|
||||||
|
action: Action name (e.g., "read")
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if user has permission
|
||||||
|
"""
|
||||||
|
permission = f"{resource}:{action}"
|
||||||
|
return permission in self.permissions
|
||||||
|
|
||||||
|
def require_user(self) -> User:
|
||||||
|
"""Get user or raise 401.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 401 if not authenticated
|
||||||
|
"""
|
||||||
|
if not self.user:
|
||||||
|
raise HTTPException(status_code=401, detail="Authentication required")
|
||||||
|
return self.user
|
||||||
|
|
||||||
|
|
||||||
|
def get_auth_context(request: Request) -> AuthContext | None:
|
||||||
|
"""Get AuthContext from request state."""
|
||||||
|
return getattr(request.state, "auth", None)
|
||||||
|
|
||||||
|
|
||||||
|
_ALL_PERMISSIONS: list[str] = [
|
||||||
|
Permissions.THREADS_READ,
|
||||||
|
Permissions.THREADS_WRITE,
|
||||||
|
Permissions.THREADS_DELETE,
|
||||||
|
Permissions.RUNS_CREATE,
|
||||||
|
Permissions.RUNS_READ,
|
||||||
|
Permissions.RUNS_CANCEL,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def _authenticate(request: Request) -> AuthContext:
|
||||||
|
"""Authenticate request and return AuthContext.
|
||||||
|
|
||||||
|
Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline.
|
||||||
|
Returns AuthContext with user=None for anonymous requests.
|
||||||
|
"""
|
||||||
|
from app.gateway.deps import get_optional_user_from_request
|
||||||
|
|
||||||
|
user = await get_optional_user_from_request(request)
|
||||||
|
if user is None:
|
||||||
|
return AuthContext(user=None, permissions=[])
|
||||||
|
|
||||||
|
# In future, permissions could be stored in user record
|
||||||
|
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||||
|
|
||||||
|
|
||||||
|
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
||||||
|
"""Decorator that authenticates the request and sets AuthContext.
|
||||||
|
|
||||||
|
Must be placed ABOVE other decorators (executes after them).
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
@router.get("/{thread_id}")
|
||||||
|
@require_auth # Bottom decorator (executes first after permission check)
|
||||||
|
@require_permission("threads", "read")
|
||||||
|
async def get_thread(thread_id: str, request: Request):
|
||||||
|
auth: AuthContext = request.state.auth
|
||||||
|
...
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If 'request' parameter is missing
|
||||||
|
"""
|
||||||
|
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
request = kwargs.get("request")
|
||||||
|
if request is None:
|
||||||
|
raise ValueError("require_auth decorator requires 'request' parameter")
|
||||||
|
|
||||||
|
# Authenticate and set context
|
||||||
|
auth_context = await _authenticate(request)
|
||||||
|
request.state.auth = auth_context
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def require_permission(
|
||||||
|
resource: str,
|
||||||
|
action: str,
|
||||||
|
owner_check: bool = False,
|
||||||
|
owner_filter_key: str = "owner_id",
|
||||||
|
inject_record: bool = False,
|
||||||
|
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
||||||
|
"""Decorator that checks permission for resource:action.
|
||||||
|
|
||||||
|
Must be used AFTER @require_auth.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
resource: Resource name (e.g., "threads", "runs")
|
||||||
|
action: Action name (e.g., "read", "write", "delete")
|
||||||
|
owner_check: If True, validates that the current user owns the resource.
|
||||||
|
Requires 'thread_id' path parameter and performs ownership check.
|
||||||
|
owner_filter_key: Field name for ownership filter (default: "owner_id")
|
||||||
|
inject_record: If True and owner_check is True, injects the thread record
|
||||||
|
into kwargs['thread_record'] for use in the handler.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
# Simple permission check
|
||||||
|
@require_permission("threads", "read")
|
||||||
|
async def get_thread(thread_id: str, request: Request):
|
||||||
|
...
|
||||||
|
|
||||||
|
# With ownership check (for /threads/{thread_id} endpoints)
|
||||||
|
@require_permission("threads", "delete", owner_check=True)
|
||||||
|
async def delete_thread(thread_id: str, request: Request):
|
||||||
|
...
|
||||||
|
|
||||||
|
# With ownership check and record injection
|
||||||
|
@require_permission("threads", "delete", owner_check=True, inject_record=True)
|
||||||
|
async def delete_thread(thread_id: str, request: Request, thread_record: dict = None):
|
||||||
|
# thread_record is injected if found
|
||||||
|
...
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
HTTPException 401: If authentication required but user is anonymous
|
||||||
|
HTTPException 403: If user lacks permission
|
||||||
|
HTTPException 404: If owner_check=True but user doesn't own the thread
|
||||||
|
ValueError: If owner_check=True but 'thread_id' parameter is missing
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
||||||
|
@functools.wraps(func)
|
||||||
|
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
||||||
|
request = kwargs.get("request")
|
||||||
|
if request is None:
|
||||||
|
raise ValueError("require_permission decorator requires 'request' parameter")
|
||||||
|
|
||||||
|
auth: AuthContext = getattr(request.state, "auth", None)
|
||||||
|
if auth is None:
|
||||||
|
auth = await _authenticate(request)
|
||||||
|
request.state.auth = auth
|
||||||
|
|
||||||
|
if not auth.is_authenticated:
|
||||||
|
raise HTTPException(status_code=401, detail="Authentication required")
|
||||||
|
|
||||||
|
# Check permission
|
||||||
|
if not auth.has_permission(resource, action):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail=f"Permission denied: {resource}:{action}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Owner check for thread-specific resources
|
||||||
|
if owner_check:
|
||||||
|
thread_id = kwargs.get("thread_id")
|
||||||
|
if thread_id is None:
|
||||||
|
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||||
|
|
||||||
|
# Get thread and verify ownership
|
||||||
|
from app.gateway.routers.threads import _store_get, get_store
|
||||||
|
|
||||||
|
store = get_store(request)
|
||||||
|
if store is not None:
|
||||||
|
record = await _store_get(store, thread_id)
|
||||||
|
if record:
|
||||||
|
owner_id = record.get("metadata", {}).get(owner_filter_key)
|
||||||
|
if owner_id and owner_id != str(auth.user.id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail=f"Thread {thread_id} not found",
|
||||||
|
)
|
||||||
|
# Inject record if requested
|
||||||
|
if inject_record:
|
||||||
|
kwargs["thread_record"] = record
|
||||||
|
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
return decorator
|
||||||
@@ -0,0 +1,112 @@
|
|||||||
|
"""CSRF protection middleware for FastAPI.
|
||||||
|
|
||||||
|
Per RFC-001:
|
||||||
|
State-changing operations require CSRF protection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
CSRF_COOKIE_NAME = "csrf_token"
|
||||||
|
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||||
|
CSRF_TOKEN_LENGTH = 64 # bytes
|
||||||
|
|
||||||
|
|
||||||
|
def is_secure_request(request: Request) -> bool:
|
||||||
|
"""Detect whether the original client request was made over HTTPS."""
|
||||||
|
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
|
||||||
|
|
||||||
|
|
||||||
|
def generate_csrf_token() -> str:
|
||||||
|
"""Generate a secure random CSRF token."""
|
||||||
|
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
||||||
|
|
||||||
|
|
||||||
|
def should_check_csrf(request: Request) -> bool:
|
||||||
|
"""Determine if a request needs CSRF validation.
|
||||||
|
|
||||||
|
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
|
||||||
|
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
|
||||||
|
"""
|
||||||
|
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||||
|
return False
|
||||||
|
|
||||||
|
path = request.url.path.rstrip("/")
|
||||||
|
# Exempt /api/v1/auth/me endpoint
|
||||||
|
if path == "/api/v1/auth/me":
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
"/api/v1/auth/logout",
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def is_auth_endpoint(request: Request) -> bool:
|
||||||
|
"""Check if the request is to an auth endpoint.
|
||||||
|
|
||||||
|
Auth endpoints don't need CSRF validation on first call (no token).
|
||||||
|
"""
|
||||||
|
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||||
|
|
||||||
|
|
||||||
|
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||||
|
_is_auth = is_auth_endpoint(request)
|
||||||
|
|
||||||
|
if should_check_csrf(request) and not _is_auth:
|
||||||
|
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||||
|
header_token = request.headers.get(CSRF_HEADER_NAME)
|
||||||
|
|
||||||
|
if not cookie_token or not header_token:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
|
||||||
|
)
|
||||||
|
|
||||||
|
if not secrets.compare_digest(cookie_token, header_token):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=403,
|
||||||
|
content={"detail": "CSRF token mismatch."},
|
||||||
|
)
|
||||||
|
|
||||||
|
response = await call_next(request)
|
||||||
|
|
||||||
|
# For auth endpoints that set up session, also set CSRF cookie
|
||||||
|
if _is_auth and request.method == "POST":
|
||||||
|
# Generate a new CSRF token for the session
|
||||||
|
csrf_token = generate_csrf_token()
|
||||||
|
is_https = is_secure_request(request)
|
||||||
|
response.set_cookie(
|
||||||
|
key=CSRF_COOKIE_NAME,
|
||||||
|
value=csrf_token,
|
||||||
|
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
|
||||||
|
secure=is_https,
|
||||||
|
samesite="strict",
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def get_csrf_token(request: Request) -> str | None:
|
||||||
|
"""Get the CSRF token from the current request's cookies.
|
||||||
|
|
||||||
|
This is useful for server-side rendering where you need to embed
|
||||||
|
token in forms or headers.
|
||||||
|
"""
|
||||||
|
return request.cookies.get(CSRF_COOKIE_NAME)
|
||||||
@@ -11,11 +11,16 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
from contextlib import AsyncExitStack, asynccontextmanager
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
|
||||||
from deerflow.runtime import RunContext, RunManager
|
from deerflow.runtime import RunContext, RunManager
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||||
|
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
@@ -127,10 +132,86 @@ def get_run_context(request: Request) -> RunContext:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> str | None:
|
# ---------------------------------------------------------------------------
|
||||||
"""Extract user identity from request.
|
# Auth helpers (used by authz.py and auth middleware)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
Phase 2: always returns None (no authentication).
|
# Cached singletons to avoid repeated instantiation per request
|
||||||
Phase 3: extract user_id from JWT / session / API key header.
|
_cached_local_provider: LocalAuthProvider | None = None
|
||||||
|
_cached_repo: SQLiteUserRepository | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_local_provider() -> LocalAuthProvider:
|
||||||
|
"""Get or create the cached LocalAuthProvider singleton."""
|
||||||
|
global _cached_local_provider, _cached_repo
|
||||||
|
if _cached_repo is None:
|
||||||
|
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||||
|
|
||||||
|
_cached_repo = SQLiteUserRepository()
|
||||||
|
if _cached_local_provider is None:
|
||||||
|
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||||
|
|
||||||
|
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
|
||||||
|
return _cached_local_provider
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user_from_request(request: Request):
|
||||||
|
"""Get the current authenticated user from the request cookie.
|
||||||
|
|
||||||
|
Raises HTTPException 401 if not authenticated.
|
||||||
"""
|
"""
|
||||||
return None
|
from app.gateway.auth import decode_token
|
||||||
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||||
|
|
||||||
|
access_token = request.cookies.get("access_token")
|
||||||
|
if not access_token:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = decode_token(access_token)
|
||||||
|
if isinstance(payload, TokenError):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
provider = get_local_provider()
|
||||||
|
user = await provider.get_user(payload.sub)
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Token version mismatch → password was changed, token is stale
|
||||||
|
if user.token_version != payload.ver:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
async def get_optional_user_from_request(request: Request):
|
||||||
|
"""Get optional authenticated user from request.
|
||||||
|
|
||||||
|
Returns None if not authenticated.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return await get_current_user_from_request(request)
|
||||||
|
except HTTPException:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(request: Request) -> str | None:
|
||||||
|
"""Extract user_id from request cookie, or None if not authenticated.
|
||||||
|
|
||||||
|
Thin adapter that returns the string id for callers that only need
|
||||||
|
identification (e.g., ``feedback.py``). Full-user callers should use
|
||||||
|
``get_current_user_from_request`` or ``get_optional_user_from_request``.
|
||||||
|
"""
|
||||||
|
user = await get_optional_user_from_request(request)
|
||||||
|
return str(user.id) if user else None
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
"""LangGraph Server auth handler — shares JWT logic with Gateway.
|
||||||
|
|
||||||
|
Loaded by LangGraph Server via langgraph.json ``auth.path``.
|
||||||
|
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
|
||||||
|
so both modes validate tokens with the same secret and rules.
|
||||||
|
|
||||||
|
Two layers:
|
||||||
|
1. @auth.authenticate — validates JWT cookie, extracts user_id,
|
||||||
|
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
|
||||||
|
2. @auth.on — returns metadata filter so each user only sees own threads
|
||||||
|
"""
|
||||||
|
|
||||||
|
import secrets
|
||||||
|
|
||||||
|
from langgraph_sdk import Auth
|
||||||
|
|
||||||
|
from app.gateway.auth.errors import TokenError
|
||||||
|
from app.gateway.auth.jwt import decode_token
|
||||||
|
from app.gateway.deps import get_local_provider
|
||||||
|
|
||||||
|
auth = Auth()
|
||||||
|
|
||||||
|
# Methods that require CSRF validation (state-changing per RFC 7231).
|
||||||
|
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
||||||
|
|
||||||
|
|
||||||
|
def _check_csrf(request) -> None:
|
||||||
|
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
|
||||||
|
|
||||||
|
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
|
||||||
|
proxied directly by nginx have the same CSRF protection.
|
||||||
|
"""
|
||||||
|
method = getattr(request, "method", "") or ""
|
||||||
|
if method.upper() not in _CSRF_METHODS:
|
||||||
|
return
|
||||||
|
|
||||||
|
cookie_token = request.cookies.get("csrf_token")
|
||||||
|
header_token = request.headers.get("x-csrf-token")
|
||||||
|
|
||||||
|
if not cookie_token or not header_token:
|
||||||
|
raise Auth.exceptions.HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="CSRF token missing. Include X-CSRF-Token header.",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not secrets.compare_digest(cookie_token, header_token):
|
||||||
|
raise Auth.exceptions.HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail="CSRF token mismatch.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@auth.authenticate
|
||||||
|
async def authenticate(request):
|
||||||
|
"""Validate the session cookie, decode JWT, and check token_version.
|
||||||
|
|
||||||
|
Same validation chain as Gateway's get_current_user_from_request:
|
||||||
|
cookie → decode JWT → DB lookup → token_version match
|
||||||
|
Also enforces CSRF on state-changing methods.
|
||||||
|
"""
|
||||||
|
# CSRF check before authentication so forged cross-site requests
|
||||||
|
# are rejected early, even if the cookie carries a valid JWT.
|
||||||
|
_check_csrf(request)
|
||||||
|
|
||||||
|
token = request.cookies.get("access_token")
|
||||||
|
if not token:
|
||||||
|
raise Auth.exceptions.HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="Not authenticated",
|
||||||
|
)
|
||||||
|
|
||||||
|
payload = decode_token(token)
|
||||||
|
if isinstance(payload, TokenError):
|
||||||
|
raise Auth.exceptions.HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail=f"Token error: {payload.value}",
|
||||||
|
)
|
||||||
|
|
||||||
|
user = await get_local_provider().get_user(payload.sub)
|
||||||
|
if user is None:
|
||||||
|
raise Auth.exceptions.HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
if user.token_version != payload.ver:
|
||||||
|
raise Auth.exceptions.HTTPException(
|
||||||
|
status_code=401,
|
||||||
|
detail="Token revoked (password changed)",
|
||||||
|
)
|
||||||
|
|
||||||
|
return payload.sub
|
||||||
|
|
||||||
|
|
||||||
|
@auth.on
|
||||||
|
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||||
|
"""Inject owner_id metadata on writes; filter by owner_id on reads.
|
||||||
|
|
||||||
|
Gateway stores thread ownership as ``metadata.owner_id``.
|
||||||
|
This handler ensures LangGraph Server enforces the same isolation.
|
||||||
|
"""
|
||||||
|
# On create/update: stamp owner_id into metadata
|
||||||
|
metadata = value.setdefault("metadata", {})
|
||||||
|
metadata["owner_id"] = ctx.user.identity
|
||||||
|
|
||||||
|
# Return filter dict — LangGraph applies it to search/read/delete
|
||||||
|
return {"owner_id": ctx.user.identity}
|
||||||
@@ -0,0 +1,303 @@
|
|||||||
|
"""Authentication endpoints."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
|
from pydantic import BaseModel, EmailStr, Field
|
||||||
|
|
||||||
|
from app.gateway.auth import (
|
||||||
|
UserResponse,
|
||||||
|
create_access_token,
|
||||||
|
)
|
||||||
|
from app.gateway.auth.config import get_auth_config
|
||||||
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
||||||
|
from app.gateway.csrf_middleware import is_secure_request
|
||||||
|
from app.gateway.deps import get_current_user_from_request, get_local_provider
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request/Response Models ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class LoginResponse(BaseModel):
|
||||||
|
"""Response model for login — token only lives in HttpOnly cookie."""
|
||||||
|
|
||||||
|
expires_in: int # seconds
|
||||||
|
needs_setup: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class RegisterRequest(BaseModel):
|
||||||
|
"""Request model for user registration."""
|
||||||
|
|
||||||
|
email: EmailStr
|
||||||
|
password: str = Field(..., min_length=8)
|
||||||
|
|
||||||
|
|
||||||
|
class ChangePasswordRequest(BaseModel):
|
||||||
|
"""Request model for password change (also handles setup flow)."""
|
||||||
|
|
||||||
|
current_password: str
|
||||||
|
new_password: str = Field(..., min_length=8)
|
||||||
|
new_email: EmailStr | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class MessageResponse(BaseModel):
|
||||||
|
"""Generic message response."""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
|
||||||
|
"""Set the access_token HttpOnly cookie on the response."""
|
||||||
|
config = get_auth_config()
|
||||||
|
is_https = is_secure_request(request)
|
||||||
|
response.set_cookie(
|
||||||
|
key="access_token",
|
||||||
|
value=token,
|
||||||
|
httponly=True,
|
||||||
|
secure=is_https,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Rate Limiting ────────────────────────────────────────────────────────
|
||||||
|
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
|
||||||
|
|
||||||
|
_MAX_LOGIN_ATTEMPTS = 5
|
||||||
|
_LOCKOUT_SECONDS = 300 # 5 minutes
|
||||||
|
|
||||||
|
# ip → (fail_count, lock_until_timestamp)
|
||||||
|
_login_attempts: dict[str, tuple[int, float]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client_ip(request: Request) -> str:
|
||||||
|
"""Extract the real client IP for rate limiting.
|
||||||
|
|
||||||
|
Uses ``X-Real-IP`` header set by nginx (``proxy_set_header X-Real-IP
|
||||||
|
$remote_addr``). Nginx unconditionally overwrites any client-supplied
|
||||||
|
``X-Real-IP``, so the value seen by Gateway is always the TCP peer IP
|
||||||
|
that nginx observed — it cannot be spoofed by the client.
|
||||||
|
|
||||||
|
``request.client.host`` is NOT reliable because uvicorn's default
|
||||||
|
``proxy_headers=True`` replaces it with the *first* entry from
|
||||||
|
``X-Forwarded-For``, which IS client-spoofable.
|
||||||
|
|
||||||
|
``X-Forwarded-For`` is intentionally NOT used for the same reason.
|
||||||
|
"""
|
||||||
|
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||||
|
if real_ip:
|
||||||
|
return real_ip
|
||||||
|
|
||||||
|
# Fallback: direct connection without nginx (e.g. unit tests, dev).
|
||||||
|
return request.client.host if request.client else "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
def _check_rate_limit(ip: str) -> None:
|
||||||
|
"""Raise 429 if the IP is currently locked out."""
|
||||||
|
record = _login_attempts.get(ip)
|
||||||
|
if record is None:
|
||||||
|
return
|
||||||
|
fail_count, lock_until = record
|
||||||
|
if fail_count >= _MAX_LOGIN_ATTEMPTS:
|
||||||
|
if time.time() < lock_until:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=429,
|
||||||
|
detail="Too many login attempts. Try again later.",
|
||||||
|
)
|
||||||
|
del _login_attempts[ip]
|
||||||
|
|
||||||
|
|
||||||
|
_MAX_TRACKED_IPS = 10000
|
||||||
|
|
||||||
|
|
||||||
|
def _record_login_failure(ip: str) -> None:
|
||||||
|
"""Record a failed login attempt for the given IP."""
|
||||||
|
# Evict expired lockouts when dict grows too large
|
||||||
|
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||||
|
now = time.time()
|
||||||
|
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
|
||||||
|
for k in expired:
|
||||||
|
del _login_attempts[k]
|
||||||
|
# If still too large, evict cheapest-to-lose half: below-threshold
|
||||||
|
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
|
||||||
|
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||||
|
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
|
||||||
|
for k, _ in by_time[: len(by_time) // 2]:
|
||||||
|
del _login_attempts[k]
|
||||||
|
|
||||||
|
record = _login_attempts.get(ip)
|
||||||
|
if record is None:
|
||||||
|
_login_attempts[ip] = (1, 0.0)
|
||||||
|
else:
|
||||||
|
new_count = record[0] + 1
|
||||||
|
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
|
||||||
|
_login_attempts[ip] = (new_count, lock_until)
|
||||||
|
|
||||||
|
|
||||||
|
def _record_login_success(ip: str) -> None:
|
||||||
|
"""Clear failure counter for the given IP on successful login."""
|
||||||
|
_login_attempts.pop(ip, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Endpoints ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login/local", response_model=LoginResponse)
|
||||||
|
async def login_local(
|
||||||
|
request: Request,
|
||||||
|
response: Response,
|
||||||
|
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||||
|
):
|
||||||
|
"""Local email/password login."""
|
||||||
|
client_ip = _get_client_ip(request)
|
||||||
|
_check_rate_limit(client_ip)
|
||||||
|
|
||||||
|
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
|
||||||
|
|
||||||
|
if user is None:
|
||||||
|
_record_login_failure(client_ip)
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
_record_login_success(client_ip)
|
||||||
|
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||||
|
_set_session_cookie(response, token, request)
|
||||||
|
|
||||||
|
return LoginResponse(
|
||||||
|
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
|
||||||
|
needs_setup=user.needs_setup,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(request: Request, response: Response, body: RegisterRequest):
|
||||||
|
"""Register a new user account (always 'user' role).
|
||||||
|
|
||||||
|
Admin is auto-created on first boot. This endpoint creates regular users.
|
||||||
|
Auto-login by setting the session cookie.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
|
||||||
|
)
|
||||||
|
|
||||||
|
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||||
|
_set_session_cookie(response, token, request)
|
||||||
|
|
||||||
|
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout", response_model=MessageResponse)
|
||||||
|
async def logout(request: Request, response: Response):
|
||||||
|
"""Logout current user by clearing the cookie."""
|
||||||
|
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
|
||||||
|
return MessageResponse(message="Successfully logged out")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/change-password", response_model=MessageResponse)
|
||||||
|
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
|
||||||
|
"""Change password for the currently authenticated user.
|
||||||
|
|
||||||
|
Also handles the first-boot setup flow:
|
||||||
|
- If new_email is provided, updates email (checks uniqueness)
|
||||||
|
- If user.needs_setup is True and new_email is given, clears needs_setup
|
||||||
|
- Always increments token_version to invalidate old sessions
|
||||||
|
- Re-issues session cookie with new token_version
|
||||||
|
"""
|
||||||
|
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||||
|
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
|
||||||
|
if user.password_hash is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
||||||
|
|
||||||
|
if not await verify_password_async(body.current_password, user.password_hash):
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
|
||||||
|
|
||||||
|
provider = get_local_provider()
|
||||||
|
|
||||||
|
# Update email if provided
|
||||||
|
if body.new_email is not None:
|
||||||
|
existing = await provider.get_user_by_email(body.new_email)
|
||||||
|
if existing and str(existing.id) != str(user.id):
|
||||||
|
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
|
||||||
|
user.email = body.new_email
|
||||||
|
|
||||||
|
# Update password + bump version
|
||||||
|
user.password_hash = await hash_password_async(body.new_password)
|
||||||
|
user.token_version += 1
|
||||||
|
|
||||||
|
# Clear setup flag if this is the setup flow
|
||||||
|
if user.needs_setup and body.new_email is not None:
|
||||||
|
user.needs_setup = False
|
||||||
|
|
||||||
|
await provider.update_user(user)
|
||||||
|
|
||||||
|
# Re-issue cookie with new token_version
|
||||||
|
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||||
|
_set_session_cookie(response, token, request)
|
||||||
|
|
||||||
|
return MessageResponse(message="Password changed successfully")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserResponse)
|
||||||
|
async def get_me(request: Request):
|
||||||
|
"""Get current authenticated user info."""
|
||||||
|
user = await get_current_user_from_request(request)
|
||||||
|
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/setup-status")
|
||||||
|
async def setup_status():
|
||||||
|
"""Check if admin account exists. Always False after first boot."""
|
||||||
|
user_count = await get_local_provider().count_users()
|
||||||
|
return {"needs_setup": user_count == 0}
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/oauth/{provider}")
|
||||||
|
async def oauth_login(provider: str):
|
||||||
|
"""Initiate OAuth login flow.
|
||||||
|
|
||||||
|
Redirects to the OAuth provider's authorization URL.
|
||||||
|
Currently a placeholder - requires OAuth provider implementation.
|
||||||
|
"""
|
||||||
|
if provider not in ["github", "google"]:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unsupported OAuth provider: {provider}",
|
||||||
|
)
|
||||||
|
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="OAuth login not yet implemented",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/callback/{provider}")
|
||||||
|
async def oauth_callback(provider: str, code: str, state: str):
|
||||||
|
"""OAuth callback endpoint.
|
||||||
|
|
||||||
|
Handles the OAuth provider's callback after user authorization.
|
||||||
|
Currently a placeholder.
|
||||||
|
"""
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
||||||
|
detail="OAuth callback not yet implemented",
|
||||||
|
)
|
||||||
@@ -8,6 +8,9 @@
|
|||||||
"graphs": {
|
"graphs": {
|
||||||
"lead_agent": "deerflow.agents:make_lead_agent"
|
"lead_agent": "deerflow.agents:make_lead_agent"
|
||||||
},
|
},
|
||||||
|
"auth": {
|
||||||
|
"path": "./app/gateway/langgraph_auth.py:auth"
|
||||||
|
},
|
||||||
"checkpointer": {
|
"checkpointer": {
|
||||||
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
|
"path": "./packages/harness/deerflow/agents/checkpointer/async_provider.py:make_checkpointer"
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,506 @@
|
|||||||
|
"""Tests for authentication module: JWT, password hashing, AuthContext, and authz decorators."""
|
||||||
|
|
||||||
|
from datetime import timedelta
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI, HTTPException
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from app.gateway.auth import create_access_token, decode_token, hash_password, verify_password
|
||||||
|
from app.gateway.auth.models import User
|
||||||
|
from app.gateway.authz import (
|
||||||
|
AuthContext,
|
||||||
|
Permissions,
|
||||||
|
get_auth_context,
|
||||||
|
require_auth,
|
||||||
|
require_permission,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Password Hashing ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_hash_password_and_verify():
|
||||||
|
"""Hashing and verification round-trip."""
|
||||||
|
password = "s3cr3tP@ssw0rd!"
|
||||||
|
hashed = hash_password(password)
|
||||||
|
assert hashed != password
|
||||||
|
assert verify_password(password, hashed) is True
|
||||||
|
assert verify_password("wrongpassword", hashed) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_hash_password_different_each_time():
|
||||||
|
"""bcrypt generates unique salts, so same password has different hashes."""
|
||||||
|
password = "testpassword"
|
||||||
|
h1 = hash_password(password)
|
||||||
|
h2 = hash_password(password)
|
||||||
|
assert h1 != h2 # Different salts
|
||||||
|
# But both verify correctly
|
||||||
|
assert verify_password(password, h1) is True
|
||||||
|
assert verify_password(password, h2) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_verify_password_rejects_empty():
|
||||||
|
"""Empty password should not verify."""
|
||||||
|
hashed = hash_password("nonempty")
|
||||||
|
assert verify_password("", hashed) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── JWT ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_and_decode_token():
|
||||||
|
"""JWT creation and decoding round-trip."""
|
||||||
|
user_id = str(uuid4())
|
||||||
|
# Set a valid JWT secret for this test
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||||
|
token = create_access_token(user_id)
|
||||||
|
assert isinstance(token, str)
|
||||||
|
|
||||||
|
payload = decode_token(token)
|
||||||
|
assert payload is not None
|
||||||
|
assert payload.sub == user_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_token_expired():
|
||||||
|
"""Expired token returns TokenError.EXPIRED."""
|
||||||
|
from app.gateway.auth.errors import TokenError
|
||||||
|
|
||||||
|
user_id = str(uuid4())
|
||||||
|
# Create token that expires immediately
|
||||||
|
token = create_access_token(user_id, expires_delta=timedelta(seconds=-1))
|
||||||
|
payload = decode_token(token)
|
||||||
|
assert payload == TokenError.EXPIRED
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_token_invalid():
|
||||||
|
"""Invalid token returns TokenError."""
|
||||||
|
from app.gateway.auth.errors import TokenError
|
||||||
|
|
||||||
|
assert isinstance(decode_token("not.a.valid.token"), TokenError)
|
||||||
|
assert isinstance(decode_token(""), TokenError)
|
||||||
|
assert isinstance(decode_token("completely-wrong"), TokenError)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_token_custom_expiry():
|
||||||
|
"""Custom expiry is respected."""
|
||||||
|
user_id = str(uuid4())
|
||||||
|
token = create_access_token(user_id, expires_delta=timedelta(hours=1))
|
||||||
|
payload = decode_token(token)
|
||||||
|
assert payload is not None
|
||||||
|
assert payload.sub == user_id
|
||||||
|
|
||||||
|
|
||||||
|
# ── AuthContext ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_context_unauthenticated():
|
||||||
|
"""AuthContext with no user."""
|
||||||
|
ctx = AuthContext(user=None, permissions=[])
|
||||||
|
assert ctx.is_authenticated is False
|
||||||
|
assert ctx.has_permission("threads", "read") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_context_authenticated_no_perms():
|
||||||
|
"""AuthContext with user but no permissions."""
|
||||||
|
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||||
|
ctx = AuthContext(user=user, permissions=[])
|
||||||
|
assert ctx.is_authenticated is True
|
||||||
|
assert ctx.has_permission("threads", "read") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_context_has_permission():
|
||||||
|
"""AuthContext permission checking."""
|
||||||
|
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||||
|
perms = [Permissions.THREADS_READ, Permissions.THREADS_WRITE]
|
||||||
|
ctx = AuthContext(user=user, permissions=perms)
|
||||||
|
assert ctx.has_permission("threads", "read") is True
|
||||||
|
assert ctx.has_permission("threads", "write") is True
|
||||||
|
assert ctx.has_permission("threads", "delete") is False
|
||||||
|
assert ctx.has_permission("runs", "read") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_context_require_user_raises():
|
||||||
|
"""require_user raises 401 when not authenticated."""
|
||||||
|
ctx = AuthContext(user=None, permissions=[])
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
ctx.require_user()
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_context_require_user_returns_user():
|
||||||
|
"""require_user returns user when authenticated."""
|
||||||
|
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||||
|
ctx = AuthContext(user=user, permissions=[])
|
||||||
|
returned = ctx.require_user()
|
||||||
|
assert returned == user
|
||||||
|
|
||||||
|
|
||||||
|
# ── get_auth_context helper ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_auth_context_not_set():
|
||||||
|
"""get_auth_context returns None when auth not set on request."""
|
||||||
|
mock_request = MagicMock()
|
||||||
|
# Make getattr return None (simulating attribute not set)
|
||||||
|
mock_request.state = MagicMock()
|
||||||
|
del mock_request.state.auth
|
||||||
|
assert get_auth_context(mock_request) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_auth_context_set():
|
||||||
|
"""get_auth_context returns the AuthContext from request."""
|
||||||
|
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||||
|
ctx = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.state.auth = ctx
|
||||||
|
|
||||||
|
assert get_auth_context(mock_request) == ctx
|
||||||
|
|
||||||
|
|
||||||
|
# ── require_auth decorator ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_auth_sets_auth_context():
|
||||||
|
"""require_auth sets auth context on request from cookie."""
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
@require_auth
|
||||||
|
async def endpoint(request: Request):
|
||||||
|
ctx = get_auth_context(request)
|
||||||
|
return {"authenticated": ctx.is_authenticated}
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
# No cookie → anonymous
|
||||||
|
response = client.get("/test")
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["authenticated"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_auth_requires_request_param():
|
||||||
|
"""require_auth raises ValueError if request parameter is missing."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
@require_auth
|
||||||
|
async def bad_endpoint(): # Missing `request` parameter
|
||||||
|
pass
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="require_auth decorator requires 'request' parameter"):
|
||||||
|
asyncio.run(bad_endpoint())
|
||||||
|
|
||||||
|
|
||||||
|
# ── require_permission decorator ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_permission_requires_auth():
|
||||||
|
"""require_permission raises 401 when not authenticated."""
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
@require_permission("threads", "read")
|
||||||
|
async def endpoint(request: Request):
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/test")
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Authentication required" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_require_permission_denies_wrong_permission():
|
||||||
|
"""User without required permission gets 403."""
|
||||||
|
from fastapi import Request
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
user = User(id=uuid4(), email="test@example.com", password_hash="hash")
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
@require_permission("threads", "delete")
|
||||||
|
async def endpoint(request: Request):
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
mock_auth = AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||||
|
|
||||||
|
with patch("app.gateway.authz._authenticate", return_value=mock_auth):
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/test")
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert "Permission denied" in response.json()["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
# ── User Model Fields ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_model_has_needs_setup_default_false():
|
||||||
|
"""New users default to needs_setup=False."""
|
||||||
|
user = User(email="test@example.com", password_hash="hash")
|
||||||
|
assert user.needs_setup is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_model_has_token_version_default_zero():
|
||||||
|
"""New users default to token_version=0."""
|
||||||
|
user = User(email="test@example.com", password_hash="hash")
|
||||||
|
assert user.token_version == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_model_needs_setup_true():
|
||||||
|
"""Auto-created admin has needs_setup=True."""
|
||||||
|
user = User(email="admin@example.com", password_hash="hash", needs_setup=True)
|
||||||
|
assert user.needs_setup is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_sqlite_round_trip_new_fields():
|
||||||
|
"""needs_setup and token_version survive create → read round-trip."""
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from app.gateway.auth.repositories import sqlite as sqlite_mod
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = os.path.join(tmpdir, "test_users.db")
|
||||||
|
old_path = sqlite_mod._resolved_db_path
|
||||||
|
old_init = sqlite_mod._table_initialized
|
||||||
|
sqlite_mod._resolved_db_path = Path(db_path)
|
||||||
|
sqlite_mod._table_initialized = False
|
||||||
|
try:
|
||||||
|
repo = sqlite_mod.SQLiteUserRepository()
|
||||||
|
user = User(
|
||||||
|
email="setup@test.com",
|
||||||
|
password_hash="fakehash",
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=True,
|
||||||
|
token_version=3,
|
||||||
|
)
|
||||||
|
created = asyncio.run(repo.create_user(user))
|
||||||
|
assert created.needs_setup is True
|
||||||
|
assert created.token_version == 3
|
||||||
|
|
||||||
|
fetched = asyncio.run(repo.get_user_by_email("setup@test.com"))
|
||||||
|
assert fetched is not None
|
||||||
|
assert fetched.needs_setup is True
|
||||||
|
assert fetched.token_version == 3
|
||||||
|
|
||||||
|
fetched.needs_setup = False
|
||||||
|
fetched.token_version = 4
|
||||||
|
asyncio.run(repo.update_user(fetched))
|
||||||
|
refetched = asyncio.run(repo.get_user_by_id(str(fetched.id)))
|
||||||
|
assert refetched.needs_setup is False
|
||||||
|
assert refetched.token_version == 4
|
||||||
|
finally:
|
||||||
|
sqlite_mod._resolved_db_path = old_path
|
||||||
|
sqlite_mod._table_initialized = old_init
|
||||||
|
|
||||||
|
|
||||||
|
# ── Token Versioning ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_jwt_encodes_ver():
|
||||||
|
"""JWT payload includes ver field."""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from app.gateway.auth.errors import TokenError
|
||||||
|
|
||||||
|
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||||
|
token = create_access_token(str(uuid4()), token_version=3)
|
||||||
|
payload = decode_token(token)
|
||||||
|
assert not isinstance(payload, TokenError)
|
||||||
|
assert payload.ver == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_jwt_default_ver_zero():
|
||||||
|
"""JWT ver defaults to 0."""
|
||||||
|
import os
|
||||||
|
|
||||||
|
from app.gateway.auth.errors import TokenError
|
||||||
|
|
||||||
|
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||||
|
token = create_access_token(str(uuid4()))
|
||||||
|
payload = decode_token(token)
|
||||||
|
assert not isinstance(payload, TokenError)
|
||||||
|
assert payload.ver == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_version_mismatch_rejects():
|
||||||
|
"""Token with stale ver is rejected by get_current_user_from_request."""
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["AUTH_JWT_SECRET"] = "test-secret-key-for-jwt-testing-minimum-32-chars"
|
||||||
|
|
||||||
|
user_id = str(uuid4())
|
||||||
|
token = create_access_token(user_id, token_version=0)
|
||||||
|
|
||||||
|
mock_user = User(id=user_id, email="test@example.com", password_hash="hash", token_version=1)
|
||||||
|
|
||||||
|
mock_request = MagicMock()
|
||||||
|
mock_request.cookies = {"access_token": token}
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider") as mock_provider_fn:
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.get_user = AsyncMock(return_value=mock_user)
|
||||||
|
mock_provider_fn.return_value = mock_provider
|
||||||
|
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
asyncio.run(get_current_user_from_request(mock_request))
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
assert "revoked" in str(exc_info.value.detail).lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ── change-password extension ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_change_password_request_accepts_new_email():
|
||||||
|
"""ChangePasswordRequest model accepts optional new_email."""
|
||||||
|
from app.gateway.routers.auth import ChangePasswordRequest
|
||||||
|
|
||||||
|
req = ChangePasswordRequest(
|
||||||
|
current_password="old",
|
||||||
|
new_password="newpassword",
|
||||||
|
new_email="new@example.com",
|
||||||
|
)
|
||||||
|
assert req.new_email == "new@example.com"
|
||||||
|
|
||||||
|
|
||||||
|
def test_change_password_request_new_email_optional():
|
||||||
|
"""ChangePasswordRequest model works without new_email."""
|
||||||
|
from app.gateway.routers.auth import ChangePasswordRequest
|
||||||
|
|
||||||
|
req = ChangePasswordRequest(current_password="old", new_password="newpassword")
|
||||||
|
assert req.new_email is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_response_includes_needs_setup():
|
||||||
|
"""LoginResponse includes needs_setup field."""
|
||||||
|
from app.gateway.routers.auth import LoginResponse
|
||||||
|
|
||||||
|
resp = LoginResponse(expires_in=3600, needs_setup=True)
|
||||||
|
assert resp.needs_setup is True
|
||||||
|
resp2 = LoginResponse(expires_in=3600)
|
||||||
|
assert resp2.needs_setup is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── Rate Limiting ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limiter_allows_under_limit():
|
||||||
|
"""Requests under the limit are allowed."""
|
||||||
|
from app.gateway.routers.auth import _check_rate_limit, _login_attempts
|
||||||
|
|
||||||
|
_login_attempts.clear()
|
||||||
|
_check_rate_limit("192.168.1.1") # Should not raise
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limiter_blocks_after_max_failures():
|
||||||
|
"""IP is blocked after 5 consecutive failures."""
|
||||||
|
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure
|
||||||
|
|
||||||
|
_login_attempts.clear()
|
||||||
|
ip = "10.0.0.1"
|
||||||
|
for _ in range(5):
|
||||||
|
_record_login_failure(ip)
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
_check_rate_limit(ip)
|
||||||
|
assert exc_info.value.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limiter_resets_on_success():
|
||||||
|
"""Successful login clears the failure counter."""
|
||||||
|
from app.gateway.routers.auth import _check_rate_limit, _login_attempts, _record_login_failure, _record_login_success
|
||||||
|
|
||||||
|
_login_attempts.clear()
|
||||||
|
ip = "10.0.0.2"
|
||||||
|
for _ in range(4):
|
||||||
|
_record_login_failure(ip)
|
||||||
|
_record_login_success(ip)
|
||||||
|
_check_rate_limit(ip) # Should not raise
|
||||||
|
|
||||||
|
|
||||||
|
# ── Client IP extraction ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_client_ip_direct_connection():
|
||||||
|
"""Without nginx (no X-Real-IP), falls back to request.client.host."""
|
||||||
|
from app.gateway.routers.auth import _get_client_ip
|
||||||
|
|
||||||
|
req = MagicMock()
|
||||||
|
req.client.host = "203.0.113.42"
|
||||||
|
req.headers = {}
|
||||||
|
assert _get_client_ip(req) == "203.0.113.42"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_client_ip_uses_x_real_ip():
|
||||||
|
"""X-Real-IP (set by nginx) is used when present."""
|
||||||
|
from app.gateway.routers.auth import _get_client_ip
|
||||||
|
|
||||||
|
req = MagicMock()
|
||||||
|
req.client.host = "10.0.0.1" # uvicorn may have replaced this with XFF[0]
|
||||||
|
req.headers = {"x-real-ip": "203.0.113.42"}
|
||||||
|
assert _get_client_ip(req) == "203.0.113.42"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_client_ip_xff_ignored():
|
||||||
|
"""X-Forwarded-For is never used; only X-Real-IP matters."""
|
||||||
|
from app.gateway.routers.auth import _get_client_ip
|
||||||
|
|
||||||
|
req = MagicMock()
|
||||||
|
req.client.host = "10.0.0.1"
|
||||||
|
req.headers = {"x-forwarded-for": "10.0.0.1, 198.51.100.5", "x-real-ip": "198.51.100.5"}
|
||||||
|
assert _get_client_ip(req) == "198.51.100.5"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_client_ip_no_real_ip_fallback():
|
||||||
|
"""No X-Real-IP → falls back to client.host (direct connection)."""
|
||||||
|
from app.gateway.routers.auth import _get_client_ip
|
||||||
|
|
||||||
|
req = MagicMock()
|
||||||
|
req.client.host = "127.0.0.1"
|
||||||
|
req.headers = {}
|
||||||
|
assert _get_client_ip(req) == "127.0.0.1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_client_ip_x_real_ip_always_preferred():
|
||||||
|
"""X-Real-IP is always preferred over client.host regardless of IP."""
|
||||||
|
from app.gateway.routers.auth import _get_client_ip
|
||||||
|
|
||||||
|
req = MagicMock()
|
||||||
|
req.client.host = "203.0.113.99"
|
||||||
|
req.headers = {"x-real-ip": "198.51.100.7"}
|
||||||
|
assert _get_client_ip(req) == "198.51.100.7"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_jwt_secret_generates_ephemeral(monkeypatch, caplog):
|
||||||
|
"""get_auth_config() auto-generates an ephemeral secret when AUTH_JWT_SECRET is unset."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import app.gateway.auth.config as config_module
|
||||||
|
|
||||||
|
config_module._auth_config = None
|
||||||
|
monkeypatch.delenv("AUTH_JWT_SECRET", raising=False)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
config = config_module.get_auth_config()
|
||||||
|
|
||||||
|
assert config.jwt_secret # non-empty ephemeral secret
|
||||||
|
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
config_module._auth_config = None
|
||||||
@@ -0,0 +1,216 @@
|
|||||||
|
"""Tests for the global AuthMiddleware (fail-closed safety net)."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from app.gateway.auth_middleware import AuthMiddleware, _is_public
|
||||||
|
|
||||||
|
# ── _is_public unit tests ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"path",
|
||||||
|
[
|
||||||
|
"/health",
|
||||||
|
"/health/",
|
||||||
|
"/docs",
|
||||||
|
"/docs/",
|
||||||
|
"/redoc",
|
||||||
|
"/openapi.json",
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
"/api/v1/auth/logout",
|
||||||
|
"/api/v1/auth/setup-status",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_public_paths(path: str):
|
||||||
|
assert _is_public(path) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"path",
|
||||||
|
[
|
||||||
|
"/api/models",
|
||||||
|
"/api/mcp/config",
|
||||||
|
"/api/memory",
|
||||||
|
"/api/skills",
|
||||||
|
"/api/threads/123",
|
||||||
|
"/api/threads/123/uploads",
|
||||||
|
"/api/agents",
|
||||||
|
"/api/channels",
|
||||||
|
"/api/runs/stream",
|
||||||
|
"/api/threads/123/runs",
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
"/api/v1/auth/change-password",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_protected_paths(path: str):
|
||||||
|
assert _is_public(path) is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── Trailing slash / normalization edge cases ─────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"path",
|
||||||
|
[
|
||||||
|
"/api/v1/auth/login/local/",
|
||||||
|
"/api/v1/auth/register/",
|
||||||
|
"/api/v1/auth/logout/",
|
||||||
|
"/api/v1/auth/setup-status/",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_public_auth_paths_with_trailing_slash(path: str):
|
||||||
|
assert _is_public(path) is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"path",
|
||||||
|
[
|
||||||
|
"/api/models/",
|
||||||
|
"/api/v1/auth/me/",
|
||||||
|
"/api/v1/auth/change-password/",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_protected_paths_with_trailing_slash(path: str):
|
||||||
|
assert _is_public(path) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_api_path_is_protected():
|
||||||
|
"""Fail-closed: any new /api/* path is protected by default."""
|
||||||
|
assert _is_public("/api/new-feature") is False
|
||||||
|
assert _is_public("/api/v2/something") is False
|
||||||
|
assert _is_public("/api/v1/auth/new-endpoint") is False
|
||||||
|
|
||||||
|
|
||||||
|
# ── Middleware integration tests ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app():
|
||||||
|
"""Create a minimal FastAPI app with AuthMiddleware for testing."""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
app.add_middleware(AuthMiddleware)
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health():
|
||||||
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
@app.get("/api/v1/auth/me")
|
||||||
|
async def auth_me():
|
||||||
|
return {"id": "1", "email": "test@test.com"}
|
||||||
|
|
||||||
|
@app.get("/api/v1/auth/setup-status")
|
||||||
|
async def setup_status():
|
||||||
|
return {"needs_setup": False}
|
||||||
|
|
||||||
|
@app.get("/api/models")
|
||||||
|
async def models_get():
|
||||||
|
return {"models": []}
|
||||||
|
|
||||||
|
@app.put("/api/mcp/config")
|
||||||
|
async def mcp_put():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@app.delete("/api/threads/abc")
|
||||||
|
async def thread_delete():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@app.patch("/api/threads/abc")
|
||||||
|
async def thread_patch():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@app.post("/api/threads/abc/runs/stream")
|
||||||
|
async def stream():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@app.get("/api/future-endpoint")
|
||||||
|
async def future():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
return TestClient(_make_app())
|
||||||
|
|
||||||
|
|
||||||
|
def test_public_path_no_cookie(client):
|
||||||
|
res = client.get("/health")
|
||||||
|
assert res.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_public_auth_path_no_cookie(client):
|
||||||
|
"""Public auth endpoints (login/register) pass without cookie."""
|
||||||
|
res = client.get("/api/v1/auth/setup-status")
|
||||||
|
assert res.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_protected_auth_path_no_cookie(client):
|
||||||
|
"""/auth/me requires cookie even though it's under /api/v1/auth/."""
|
||||||
|
res = client.get("/api/v1/auth/me")
|
||||||
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_protected_path_no_cookie_returns_401(client):
|
||||||
|
res = client.get("/api/models")
|
||||||
|
assert res.status_code == 401
|
||||||
|
body = res.json()
|
||||||
|
assert body["detail"]["code"] == "not_authenticated"
|
||||||
|
|
||||||
|
|
||||||
|
def test_protected_path_with_cookie_passes(client):
|
||||||
|
res = client.get("/api/models", cookies={"access_token": "some-token"})
|
||||||
|
assert res.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_protected_post_no_cookie_returns_401(client):
|
||||||
|
res = client.post("/api/threads/abc/runs/stream")
|
||||||
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── Method matrix: PUT/DELETE/PATCH also protected ────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_protected_put_no_cookie(client):
|
||||||
|
res = client.put("/api/mcp/config")
|
||||||
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_protected_delete_no_cookie(client):
|
||||||
|
res = client.delete("/api/threads/abc")
|
||||||
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_protected_patch_no_cookie(client):
|
||||||
|
res = client.patch("/api/threads/abc")
|
||||||
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_put_with_cookie_passes(client):
|
||||||
|
client.cookies.set("access_token", "tok")
|
||||||
|
res = client.put("/api/mcp/config")
|
||||||
|
assert res.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_with_cookie_passes(client):
|
||||||
|
client.cookies.set("access_token", "tok")
|
||||||
|
res = client.delete("/api/threads/abc")
|
||||||
|
assert res.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fail-closed: unknown future endpoints ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_endpoint_no_cookie_returns_401(client):
|
||||||
|
"""Any new /api/* endpoint is blocked by default without cookie."""
|
||||||
|
res = client.get("/api/future-endpoint")
|
||||||
|
assert res.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_endpoint_with_cookie_passes(client):
|
||||||
|
client.cookies.set("access_token", "tok")
|
||||||
|
res = client.get("/api/future-endpoint")
|
||||||
|
assert res.status_code == 200
|
||||||
@@ -0,0 +1,675 @@
|
|||||||
|
"""Tests for auth type system hardening.
|
||||||
|
|
||||||
|
Covers structured error responses, typed decode_token callers,
|
||||||
|
CSRF middleware path matching, config-driven cookie security,
|
||||||
|
and unhappy paths / edge cases for all auth boundaries.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import secrets
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import jwt as pyjwt
|
||||||
|
import pytest
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||||
|
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||||
|
from app.gateway.auth.jwt import decode_token
|
||||||
|
from app.gateway.csrf_middleware import (
|
||||||
|
CSRF_COOKIE_NAME,
|
||||||
|
CSRF_HEADER_NAME,
|
||||||
|
CSRFMiddleware,
|
||||||
|
is_auth_endpoint,
|
||||||
|
should_check_csrf,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Setup ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_config():
|
||||||
|
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||||
|
|
||||||
|
|
||||||
|
# ── CSRF Middleware Path Matching ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeRequest:
|
||||||
|
"""Minimal request mock for CSRF path matching tests."""
|
||||||
|
|
||||||
|
def __init__(self, path: str, method: str = "POST"):
|
||||||
|
self.method = method
|
||||||
|
|
||||||
|
class _URL:
|
||||||
|
def __init__(self, p):
|
||||||
|
self.path = p
|
||||||
|
|
||||||
|
self.url = _URL(path)
|
||||||
|
self.cookies = {}
|
||||||
|
self.headers = {}
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_exempts_login_local():
|
||||||
|
"""login/local (actual route) should be exempt from CSRF."""
|
||||||
|
req = _FakeRequest("/api/v1/auth/login/local")
|
||||||
|
assert is_auth_endpoint(req) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_exempts_login_local_trailing_slash():
|
||||||
|
"""Trailing slash should also be exempt."""
|
||||||
|
req = _FakeRequest("/api/v1/auth/login/local/")
|
||||||
|
assert is_auth_endpoint(req) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_exempts_logout():
|
||||||
|
req = _FakeRequest("/api/v1/auth/logout")
|
||||||
|
assert is_auth_endpoint(req) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_exempts_register():
|
||||||
|
req = _FakeRequest("/api/v1/auth/register")
|
||||||
|
assert is_auth_endpoint(req) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_does_not_exempt_old_login_path():
|
||||||
|
"""Old /api/v1/auth/login (without /local) should NOT be exempt."""
|
||||||
|
req = _FakeRequest("/api/v1/auth/login")
|
||||||
|
assert is_auth_endpoint(req) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_does_not_exempt_me():
|
||||||
|
req = _FakeRequest("/api/v1/auth/me")
|
||||||
|
assert is_auth_endpoint(req) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_skips_get_requests():
|
||||||
|
req = _FakeRequest("/api/v1/auth/me", method="GET")
|
||||||
|
assert should_check_csrf(req) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_checks_post_to_protected():
|
||||||
|
req = _FakeRequest("/api/v1/some/endpoint", method="POST")
|
||||||
|
assert should_check_csrf(req) is True
|
||||||
|
|
||||||
|
|
||||||
|
# ── Structured Error Response Format ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_error_response_has_code_and_message():
|
||||||
|
"""All auth errors should have structured {code, message} format."""
|
||||||
|
err = AuthErrorResponse(
|
||||||
|
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||||
|
message="Wrong password",
|
||||||
|
)
|
||||||
|
d = err.model_dump()
|
||||||
|
assert "code" in d
|
||||||
|
assert "message" in d
|
||||||
|
assert d["code"] == "invalid_credentials"
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_error_response_all_codes_serializable():
|
||||||
|
"""Every AuthErrorCode should be serializable in AuthErrorResponse."""
|
||||||
|
for code in AuthErrorCode:
|
||||||
|
err = AuthErrorResponse(code=code, message=f"Test {code.value}")
|
||||||
|
d = err.model_dump()
|
||||||
|
assert d["code"] == code.value
|
||||||
|
|
||||||
|
|
||||||
|
# ── decode_token Caller Pattern ──────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_token_expired_maps_to_token_expired_code():
|
||||||
|
"""TokenError.EXPIRED should map to AuthErrorCode.TOKEN_EXPIRED."""
|
||||||
|
_setup_config()
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import jwt as pyjwt
|
||||||
|
|
||||||
|
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||||
|
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||||
|
result = decode_token(token)
|
||||||
|
assert result == TokenError.EXPIRED
|
||||||
|
|
||||||
|
# Verify the mapping pattern used in route handlers
|
||||||
|
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||||
|
assert code == AuthErrorCode.TOKEN_EXPIRED
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_token_invalid_sig_maps_to_token_invalid_code():
|
||||||
|
"""TokenError.INVALID_SIGNATURE should map to AuthErrorCode.TOKEN_INVALID."""
|
||||||
|
_setup_config()
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
|
||||||
|
import jwt as pyjwt
|
||||||
|
|
||||||
|
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||||
|
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
|
||||||
|
result = decode_token(token)
|
||||||
|
assert result == TokenError.INVALID_SIGNATURE
|
||||||
|
|
||||||
|
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||||
|
assert code == AuthErrorCode.TOKEN_INVALID
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_token_malformed_maps_to_token_invalid_code():
|
||||||
|
"""TokenError.MALFORMED should map to AuthErrorCode.TOKEN_INVALID."""
|
||||||
|
_setup_config()
|
||||||
|
result = decode_token("garbage")
|
||||||
|
assert result == TokenError.MALFORMED
|
||||||
|
|
||||||
|
code = AuthErrorCode.TOKEN_EXPIRED if result == TokenError.EXPIRED else AuthErrorCode.TOKEN_INVALID
|
||||||
|
assert code == AuthErrorCode.TOKEN_INVALID
|
||||||
|
|
||||||
|
|
||||||
|
# ── Login Response Format ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_response_model_has_no_access_token():
|
||||||
|
"""LoginResponse should NOT contain access_token field (RFC-001)."""
|
||||||
|
from app.gateway.routers.auth import LoginResponse
|
||||||
|
|
||||||
|
resp = LoginResponse(expires_in=604800)
|
||||||
|
d = resp.model_dump()
|
||||||
|
assert "access_token" not in d
|
||||||
|
assert "expires_in" in d
|
||||||
|
assert d["expires_in"] == 604800
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_response_model_fields():
|
||||||
|
"""LoginResponse has expires_in and needs_setup."""
|
||||||
|
from app.gateway.routers.auth import LoginResponse
|
||||||
|
|
||||||
|
fields = set(LoginResponse.model_fields.keys())
|
||||||
|
assert fields == {"expires_in", "needs_setup"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── AuthConfig in Route ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_config_token_expiry_used_in_login_response():
|
||||||
|
"""LoginResponse.expires_in should come from config.token_expiry_days."""
|
||||||
|
from app.gateway.routers.auth import LoginResponse
|
||||||
|
|
||||||
|
expected_seconds = 14 * 24 * 3600
|
||||||
|
resp = LoginResponse(expires_in=expected_seconds)
|
||||||
|
assert resp.expires_in == expected_seconds
|
||||||
|
|
||||||
|
|
||||||
|
# ── UserResponse Type Preservation ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_response_system_role_literal():
|
||||||
|
"""UserResponse.system_role should only accept 'admin' or 'user'."""
|
||||||
|
from app.gateway.auth.models import UserResponse
|
||||||
|
|
||||||
|
# Valid roles
|
||||||
|
resp = UserResponse(id="1", email="a@b.com", system_role="admin")
|
||||||
|
assert resp.system_role == "admin"
|
||||||
|
|
||||||
|
resp = UserResponse(id="1", email="a@b.com", system_role="user")
|
||||||
|
assert resp.system_role == "user"
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_response_rejects_invalid_role():
|
||||||
|
"""UserResponse should reject invalid system_role values."""
|
||||||
|
from app.gateway.auth.models import UserResponse
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
UserResponse(id="1", email="a@b.com", system_role="superadmin")
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════
|
||||||
|
# UNHAPPY PATHS / EDGE CASES
|
||||||
|
# ══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
# ── get_current_user structured 401 responses ────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_current_user_no_cookie_returns_not_authenticated():
|
||||||
|
"""No cookie → 401 with code=not_authenticated."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
mock_request = type("MockRequest", (), {"cookies": {}})()
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
asyncio.run(get_current_user_from_request(mock_request))
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
detail = exc_info.value.detail
|
||||||
|
assert detail["code"] == "not_authenticated"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_current_user_expired_token_returns_token_expired():
|
||||||
|
"""Expired token → 401 with code=token_expired."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
_setup_config()
|
||||||
|
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||||
|
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||||
|
|
||||||
|
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
asyncio.run(get_current_user_from_request(mock_request))
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
detail = exc_info.value.detail
|
||||||
|
assert detail["code"] == "token_expired"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_current_user_invalid_token_returns_token_invalid():
|
||||||
|
"""Bad signature → 401 with code=token_invalid."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
_setup_config()
|
||||||
|
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||||
|
token = pyjwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||||
|
|
||||||
|
mock_request = type("MockRequest", (), {"cookies": {"access_token": token}})()
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
asyncio.run(get_current_user_from_request(mock_request))
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
detail = exc_info.value.detail
|
||||||
|
assert detail["code"] == "token_invalid"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_current_user_malformed_token_returns_token_invalid():
|
||||||
|
"""Garbage token → 401 with code=token_invalid."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
from app.gateway.deps import get_current_user_from_request
|
||||||
|
|
||||||
|
_setup_config()
|
||||||
|
mock_request = type("MockRequest", (), {"cookies": {"access_token": "not-a-jwt"}})()
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
asyncio.run(get_current_user_from_request(mock_request))
|
||||||
|
assert exc_info.value.status_code == 401
|
||||||
|
detail = exc_info.value.detail
|
||||||
|
assert detail["code"] == "token_invalid"
|
||||||
|
|
||||||
|
|
||||||
|
# ── decode_token edge cases ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_token_empty_string_returns_malformed():
|
||||||
|
_setup_config()
|
||||||
|
result = decode_token("")
|
||||||
|
assert result == TokenError.MALFORMED
|
||||||
|
|
||||||
|
|
||||||
|
def test_decode_token_whitespace_returns_malformed():
|
||||||
|
_setup_config()
|
||||||
|
result = decode_token(" ")
|
||||||
|
assert result == TokenError.MALFORMED
|
||||||
|
|
||||||
|
|
||||||
|
# ── AuthConfig validation edge cases ─────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_config_missing_jwt_secret_raises():
|
||||||
|
"""AuthConfig requires jwt_secret — no default allowed."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AuthConfig()
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_config_token_expiry_zero_raises():
|
||||||
|
"""token_expiry_days must be >= 1."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AuthConfig(jwt_secret="secret", token_expiry_days=0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_config_token_expiry_31_raises():
|
||||||
|
"""token_expiry_days must be <= 30."""
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
AuthConfig(jwt_secret="secret", token_expiry_days=31)
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_config_token_expiry_boundary_1_ok():
|
||||||
|
config = AuthConfig(jwt_secret="secret", token_expiry_days=1)
|
||||||
|
assert config.token_expiry_days == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_config_token_expiry_boundary_30_ok():
|
||||||
|
config = AuthConfig(jwt_secret="secret", token_expiry_days=30)
|
||||||
|
assert config.token_expiry_days == 30
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_auth_config_missing_env_var_generates_ephemeral(caplog):
|
||||||
|
"""get_auth_config() auto-generates ephemeral secret when AUTH_JWT_SECRET is unset."""
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import app.gateway.auth.config as cfg
|
||||||
|
|
||||||
|
old = cfg._auth_config
|
||||||
|
cfg._auth_config = None
|
||||||
|
try:
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||||
|
with caplog.at_level(logging.WARNING):
|
||||||
|
config = cfg.get_auth_config()
|
||||||
|
assert config.jwt_secret
|
||||||
|
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||||
|
finally:
|
||||||
|
cfg._auth_config = old
|
||||||
|
|
||||||
|
|
||||||
|
# ── CSRF middleware integration (unhappy paths) ──────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_csrf_app():
|
||||||
|
"""Create a minimal FastAPI app with CSRFMiddleware for testing."""
|
||||||
|
from fastapi import HTTPException as _HTTPException
|
||||||
|
from fastapi.responses import JSONResponse as _JSONResponse
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
@app.exception_handler(_HTTPException)
|
||||||
|
async def _http_exc_handler(request, exc):
|
||||||
|
return _JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||||
|
|
||||||
|
app.add_middleware(CSRFMiddleware)
|
||||||
|
|
||||||
|
@app.post("/api/v1/test/protected")
|
||||||
|
async def protected():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@app.post("/api/v1/auth/login/local")
|
||||||
|
async def login():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
@app.get("/api/v1/test/read")
|
||||||
|
async def read_endpoint():
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_middleware_blocks_post_without_token():
|
||||||
|
"""POST to protected endpoint without CSRF token → 403 with structured detail."""
|
||||||
|
client = TestClient(_make_csrf_app())
|
||||||
|
resp = client.post("/api/v1/test/protected")
|
||||||
|
assert resp.status_code == 403
|
||||||
|
assert "CSRF" in resp.json()["detail"]
|
||||||
|
assert "missing" in resp.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_middleware_blocks_post_with_mismatched_token():
|
||||||
|
"""POST with mismatched CSRF cookie/header → 403 with mismatch detail."""
|
||||||
|
client = TestClient(_make_csrf_app())
|
||||||
|
client.cookies.set(CSRF_COOKIE_NAME, "token-a")
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/test/protected",
|
||||||
|
headers={CSRF_HEADER_NAME: "token-b"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
|
assert "mismatch" in resp.json()["detail"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_middleware_allows_post_with_matching_token():
|
||||||
|
"""POST with matching CSRF cookie/header → 200."""
|
||||||
|
client = TestClient(_make_csrf_app())
|
||||||
|
token = secrets.token_urlsafe(64)
|
||||||
|
client.cookies.set(CSRF_COOKIE_NAME, token)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/test/protected",
|
||||||
|
headers={CSRF_HEADER_NAME: token},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_middleware_allows_get_without_token():
|
||||||
|
"""GET requests bypass CSRF check."""
|
||||||
|
client = TestClient(_make_csrf_app())
|
||||||
|
resp = client.get("/api/v1/test/read")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_middleware_exempts_login_local():
|
||||||
|
"""POST to login/local is exempt from CSRF (no token yet)."""
|
||||||
|
client = TestClient(_make_csrf_app())
|
||||||
|
resp = client.post("/api/v1/auth/login/local")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_middleware_sets_cookie_on_auth_endpoint():
|
||||||
|
"""Auth endpoints should receive a CSRF cookie in response."""
|
||||||
|
client = TestClient(_make_csrf_app())
|
||||||
|
resp = client.post("/api/v1/auth/login/local")
|
||||||
|
assert CSRF_COOKIE_NAME in resp.cookies
|
||||||
|
|
||||||
|
|
||||||
|
# ── UserResponse edge cases ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_response_missing_required_fields():
|
||||||
|
"""UserResponse with missing fields → ValidationError."""
|
||||||
|
from app.gateway.auth.models import UserResponse
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
UserResponse(id="1") # missing email, system_role
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
UserResponse(id="1", email="a@b.com") # missing system_role
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_response_empty_string_role_rejected():
|
||||||
|
"""Empty string is not a valid role."""
|
||||||
|
from app.gateway.auth.models import UserResponse
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
UserResponse(id="1", email="a@b.com", system_role="")
|
||||||
|
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════════════════
|
||||||
|
# HTTP-LEVEL API CONTRACT TESTS
|
||||||
|
# ══════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
|
||||||
|
def _make_auth_app():
|
||||||
|
"""Create FastAPI app with auth routes for contract testing."""
|
||||||
|
from app.gateway.app import create_app
|
||||||
|
|
||||||
|
return create_app()
|
||||||
|
|
||||||
|
|
||||||
|
def _get_auth_client():
|
||||||
|
"""Get TestClient for auth API contract tests."""
|
||||||
|
return TestClient(_make_auth_app())
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_auth_me_no_cookie_returns_structured_401():
|
||||||
|
"""/api/v1/auth/me without cookie → 401 with {code: 'not_authenticated'}."""
|
||||||
|
_setup_config()
|
||||||
|
client = _get_auth_client()
|
||||||
|
resp = client.get("/api/v1/auth/me")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
body = resp.json()
|
||||||
|
assert body["detail"]["code"] == "not_authenticated"
|
||||||
|
assert "message" in body["detail"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_auth_me_expired_token_returns_structured_401():
|
||||||
|
"""/api/v1/auth/me with expired token → 401 with {code: 'token_expired'}."""
|
||||||
|
_setup_config()
|
||||||
|
expired = {"sub": "u1", "exp": datetime.now(UTC) - timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||||
|
token = pyjwt.encode(expired, _TEST_SECRET, algorithm="HS256")
|
||||||
|
|
||||||
|
client = _get_auth_client()
|
||||||
|
client.cookies.set("access_token", token)
|
||||||
|
resp = client.get("/api/v1/auth/me")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
body = resp.json()
|
||||||
|
assert body["detail"]["code"] == "token_expired"
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_auth_me_invalid_sig_returns_structured_401():
|
||||||
|
"""/api/v1/auth/me with bad signature → 401 with {code: 'token_invalid'}."""
|
||||||
|
_setup_config()
|
||||||
|
payload = {"sub": "u1", "exp": datetime.now(UTC) + timedelta(hours=1), "iat": datetime.now(UTC)}
|
||||||
|
token = pyjwt.encode(payload, "wrong-key", algorithm="HS256")
|
||||||
|
|
||||||
|
client = _get_auth_client()
|
||||||
|
client.cookies.set("access_token", token)
|
||||||
|
resp = client.get("/api/v1/auth/me")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
body = resp.json()
|
||||||
|
assert body["detail"]["code"] == "token_invalid"
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_login_bad_credentials_returns_structured_401():
|
||||||
|
"""Login with wrong password → 401 with {code: 'invalid_credentials'}."""
|
||||||
|
_setup_config()
|
||||||
|
client = _get_auth_client()
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
data={"username": "nonexistent@test.com", "password": "wrongpassword"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
body = resp.json()
|
||||||
|
assert body["detail"]["code"] == "invalid_credentials"
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_login_success_no_token_in_body():
|
||||||
|
"""Successful login → response body has expires_in but NOT access_token."""
|
||||||
|
_setup_config()
|
||||||
|
client = _get_auth_client()
|
||||||
|
# Register first
|
||||||
|
client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "contract-test@test.com", "password": "securepassword123"},
|
||||||
|
)
|
||||||
|
# Login
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
data={"username": "contract-test@test.com", "password": "securepassword123"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert "expires_in" in body
|
||||||
|
assert "access_token" not in body
|
||||||
|
# Token should be in cookie, not body
|
||||||
|
assert "access_token" in resp.cookies
|
||||||
|
|
||||||
|
|
||||||
|
def test_api_register_duplicate_returns_structured_400():
|
||||||
|
"""Register with duplicate email → 400 with {code: 'email_already_exists'}."""
|
||||||
|
_setup_config()
|
||||||
|
client = _get_auth_client()
|
||||||
|
email = "dup-contract-test@test.com"
|
||||||
|
# First register
|
||||||
|
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
|
||||||
|
# Duplicate
|
||||||
|
resp = client.post("/api/v1/auth/register", json={"email": email, "password": "password456"})
|
||||||
|
assert resp.status_code == 400
|
||||||
|
body = resp.json()
|
||||||
|
assert body["detail"]["code"] == "email_already_exists"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cookie security: HTTP vs HTTPS ────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _unique_email(prefix: str) -> str:
|
||||||
|
return f"{prefix}-{secrets.token_hex(4)}@test.com"
|
||||||
|
|
||||||
|
|
||||||
|
def _get_set_cookie_headers(resp) -> list[str]:
|
||||||
|
"""Extract all set-cookie header values from a TestClient response."""
|
||||||
|
return [v for k, v in resp.headers.multi_items() if k.lower() == "set-cookie"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_http_cookie_httponly_true_secure_false():
|
||||||
|
"""HTTP register → access_token cookie is httponly=True, secure=False, no max_age."""
|
||||||
|
_setup_config()
|
||||||
|
client = _get_auth_client()
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": _unique_email("http-cookie"), "password": "password123"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
cookie_header = resp.headers.get("set-cookie", "")
|
||||||
|
assert "access_token=" in cookie_header
|
||||||
|
assert "httponly" in cookie_header.lower()
|
||||||
|
assert "secure" not in cookie_header.lower().replace("samesite", "")
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_https_cookie_httponly_true_secure_true():
|
||||||
|
"""HTTPS register (x-forwarded-proto) → access_token cookie is httponly=True, secure=True, has max_age."""
|
||||||
|
_setup_config()
|
||||||
|
client = _get_auth_client()
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": _unique_email("https-cookie"), "password": "password123"},
|
||||||
|
headers={"x-forwarded-proto": "https"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
cookie_header = resp.headers.get("set-cookie", "")
|
||||||
|
assert "access_token=" in cookie_header
|
||||||
|
assert "httponly" in cookie_header.lower()
|
||||||
|
assert "secure" in cookie_header.lower()
|
||||||
|
assert "max-age" in cookie_header.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_login_https_sets_secure_cookie():
|
||||||
|
"""HTTPS login → access_token cookie has secure flag."""
|
||||||
|
_setup_config()
|
||||||
|
client = _get_auth_client()
|
||||||
|
email = _unique_email("https-login")
|
||||||
|
client.post("/api/v1/auth/register", json={"email": email, "password": "password123"})
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login/local",
|
||||||
|
data={"username": email, "password": "password123"},
|
||||||
|
headers={"x-forwarded-proto": "https"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
cookie_header = resp.headers.get("set-cookie", "")
|
||||||
|
assert "access_token=" in cookie_header
|
||||||
|
assert "httponly" in cookie_header.lower()
|
||||||
|
assert "secure" in cookie_header.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_cookie_secure_on_https():
|
||||||
|
"""HTTPS register → csrf_token cookie has secure flag but NOT httponly."""
|
||||||
|
_setup_config()
|
||||||
|
client = _get_auth_client()
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": _unique_email("csrf-https"), "password": "password123"},
|
||||||
|
headers={"x-forwarded-proto": "https"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||||
|
assert csrf_cookies, "csrf_token cookie not set on HTTPS register"
|
||||||
|
csrf_header = csrf_cookies[0]
|
||||||
|
assert "secure" in csrf_header.lower()
|
||||||
|
assert "httponly" not in csrf_header.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_cookie_not_secure_on_http():
|
||||||
|
"""HTTP register → csrf_token cookie does NOT have secure flag."""
|
||||||
|
_setup_config()
|
||||||
|
client = _get_auth_client()
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": _unique_email("csrf-http"), "password": "password123"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
csrf_cookies = [h for h in _get_set_cookie_headers(resp) if "csrf_token=" in h]
|
||||||
|
assert csrf_cookies, "csrf_token cookie not set on HTTP register"
|
||||||
|
csrf_header = csrf_cookies[0]
|
||||||
|
assert "secure" not in csrf_header.lower().replace("samesite", "")
|
||||||
@@ -0,0 +1,214 @@
|
|||||||
|
"""Tests for _ensure_admin_user() in app.py.
|
||||||
|
|
||||||
|
Covers: first-boot admin creation, auto-reset on needs_setup=True,
|
||||||
|
no-op on needs_setup=False, migration, and edge cases.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-ensure-admin-testing-min-32")
|
||||||
|
|
||||||
|
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||||
|
from app.gateway.auth.models import User
|
||||||
|
|
||||||
|
_JWT_SECRET = "test-secret-key-ensure-admin-testing-min-32"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _setup_auth_config():
|
||||||
|
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||||
|
yield
|
||||||
|
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||||
|
|
||||||
|
|
||||||
|
def _make_app_stub(store=None):
|
||||||
|
"""Minimal app-like object with state.store."""
|
||||||
|
app = SimpleNamespace()
|
||||||
|
app.state = SimpleNamespace()
|
||||||
|
app.state.store = store
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
def _make_provider(user_count=0, admin_user=None):
|
||||||
|
p = AsyncMock()
|
||||||
|
p.count_users = AsyncMock(return_value=user_count)
|
||||||
|
p.create_user = AsyncMock(
|
||||||
|
side_effect=lambda **kw: User(
|
||||||
|
email=kw["email"],
|
||||||
|
password_hash="hashed",
|
||||||
|
system_role=kw.get("system_role", "user"),
|
||||||
|
needs_setup=kw.get("needs_setup", False),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
p.get_user_by_email = AsyncMock(return_value=admin_user)
|
||||||
|
p.update_user = AsyncMock(side_effect=lambda u: u)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
# ── First boot: no users ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_first_boot_creates_admin():
|
||||||
|
"""count_users==0 → create admin with needs_setup=True."""
|
||||||
|
provider = _make_provider(user_count=0)
|
||||||
|
app = _make_app_stub()
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
provider.create_user.assert_called_once()
|
||||||
|
call_kwargs = provider.create_user.call_args[1]
|
||||||
|
assert call_kwargs["email"] == "admin@deerflow.dev"
|
||||||
|
assert call_kwargs["system_role"] == "admin"
|
||||||
|
assert call_kwargs["needs_setup"] is True
|
||||||
|
assert len(call_kwargs["password"]) > 10 # random password generated
|
||||||
|
|
||||||
|
|
||||||
|
def test_first_boot_triggers_migration_if_store_present():
|
||||||
|
"""First boot with store → _migrate_orphaned_threads called."""
|
||||||
|
provider = _make_provider(user_count=0)
|
||||||
|
store = AsyncMock()
|
||||||
|
store.asearch = AsyncMock(return_value=[])
|
||||||
|
app = _make_app_stub(store=store)
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
store.asearch.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_first_boot_no_store_skips_migration():
|
||||||
|
"""First boot without store → no crash, migration skipped."""
|
||||||
|
provider = _make_provider(user_count=0)
|
||||||
|
app = _make_app_stub(store=None)
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
provider.create_user.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Subsequent boot: needs_setup=True → auto-reset ───────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_needs_setup_true_resets_password():
|
||||||
|
"""Existing admin with needs_setup=True → password reset + token_version bumped."""
|
||||||
|
admin = User(
|
||||||
|
email="admin@deerflow.dev",
|
||||||
|
password_hash="old-hash",
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=True,
|
||||||
|
token_version=0,
|
||||||
|
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||||
|
)
|
||||||
|
provider = _make_provider(user_count=1, admin_user=admin)
|
||||||
|
app = _make_app_stub()
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
# Password was reset
|
||||||
|
provider.update_user.assert_called_once()
|
||||||
|
updated = provider.update_user.call_args[0][0]
|
||||||
|
assert updated.password_hash == "new-hash"
|
||||||
|
assert updated.token_version == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_needs_setup_true_consecutive_resets_increment_version():
|
||||||
|
"""Two boots with needs_setup=True → token_version increments each time."""
|
||||||
|
admin = User(
|
||||||
|
email="admin@deerflow.dev",
|
||||||
|
password_hash="hash",
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=True,
|
||||||
|
token_version=3,
|
||||||
|
created_at=datetime.now(UTC) - timedelta(seconds=30),
|
||||||
|
)
|
||||||
|
provider = _make_provider(user_count=1, admin_user=admin)
|
||||||
|
app = _make_app_stub()
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="new-hash"):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
updated = provider.update_user.call_args[0][0]
|
||||||
|
assert updated.token_version == 4
|
||||||
|
|
||||||
|
|
||||||
|
# ── Subsequent boot: needs_setup=False → no-op ──────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_needs_setup_false_no_reset():
|
||||||
|
"""Admin with needs_setup=False → no password reset, no update."""
|
||||||
|
admin = User(
|
||||||
|
email="admin@deerflow.dev",
|
||||||
|
password_hash="stable-hash",
|
||||||
|
system_role="admin",
|
||||||
|
needs_setup=False,
|
||||||
|
token_version=2,
|
||||||
|
)
|
||||||
|
provider = _make_provider(user_count=1, admin_user=admin)
|
||||||
|
app = _make_app_stub()
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
provider.update_user.assert_not_called()
|
||||||
|
assert admin.password_hash == "stable-hash"
|
||||||
|
assert admin.token_version == 2
|
||||||
|
|
||||||
|
|
||||||
|
# ── Edge cases ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_admin_email_found_no_crash():
|
||||||
|
"""Users exist but no admin@deerflow.dev → no crash, no reset."""
|
||||||
|
provider = _make_provider(user_count=3, admin_user=None)
|
||||||
|
app = _make_app_stub()
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
provider.update_user.assert_not_called()
|
||||||
|
provider.create_user.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
def test_migration_failure_is_non_fatal():
|
||||||
|
"""_migrate_orphaned_threads exception is caught and logged."""
|
||||||
|
provider = _make_provider(user_count=0)
|
||||||
|
store = AsyncMock()
|
||||||
|
store.asearch = AsyncMock(side_effect=RuntimeError("store crashed"))
|
||||||
|
app = _make_app_stub(store=store)
|
||||||
|
|
||||||
|
with patch("app.gateway.deps.get_local_provider", return_value=provider):
|
||||||
|
with patch("app.gateway.auth.password.hash_password_async", new_callable=AsyncMock, return_value="hashed"):
|
||||||
|
from app.gateway.app import _ensure_admin_user
|
||||||
|
|
||||||
|
# Should not raise
|
||||||
|
asyncio.run(_ensure_admin_user(app))
|
||||||
|
|
||||||
|
provider.create_user.assert_called_once()
|
||||||
@@ -0,0 +1,312 @@
|
|||||||
|
"""Tests for LangGraph Server auth handler (langgraph_auth.py).
|
||||||
|
|
||||||
|
Validates that the LangGraph auth layer enforces the same rules as Gateway:
|
||||||
|
cookie → JWT decode → DB lookup → token_version check → owner filter
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from datetime import timedelta
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
os.environ.setdefault("AUTH_JWT_SECRET", "test-secret-key-for-langgraph-auth-testing-min-32")
|
||||||
|
|
||||||
|
from langgraph_sdk import Auth
|
||||||
|
|
||||||
|
from app.gateway.auth.config import AuthConfig, set_auth_config
|
||||||
|
from app.gateway.auth.jwt import create_access_token, decode_token
|
||||||
|
from app.gateway.auth.models import User
|
||||||
|
from app.gateway.langgraph_auth import add_owner_filter, authenticate
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_JWT_SECRET = "test-secret-key-for-langgraph-auth-testing-min-32"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _setup_auth_config():
|
||||||
|
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||||
|
yield
|
||||||
|
set_auth_config(AuthConfig(jwt_secret=_JWT_SECRET))
|
||||||
|
|
||||||
|
|
||||||
|
def _req(cookies=None, method="GET", headers=None):
|
||||||
|
return SimpleNamespace(cookies=cookies or {}, method=method, headers=headers or {})
|
||||||
|
|
||||||
|
|
||||||
|
def _user(user_id=None, token_version=0):
|
||||||
|
return User(email="test@example.com", password_hash="fakehash", system_role="user", id=user_id or uuid4(), token_version=token_version)
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_provider(user=None):
|
||||||
|
p = AsyncMock()
|
||||||
|
p.get_user = AsyncMock(return_value=user)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
# ── @auth.authenticate ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_cookie_raises_401():
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req()))
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
assert "Not authenticated" in str(exc.value.detail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_jwt_raises_401():
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req({"access_token": "garbage"})))
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
assert "Token error" in str(exc.value.detail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_expired_jwt_raises_401():
|
||||||
|
token = create_access_token("user-1", expires_delta=timedelta(seconds=-1))
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req({"access_token": token})))
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_user_not_found_raises_401():
|
||||||
|
token = create_access_token("ghost")
|
||||||
|
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(None)):
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req({"access_token": token})))
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
assert "User not found" in str(exc.value.detail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_version_mismatch_raises_401():
|
||||||
|
user = _user(token_version=2)
|
||||||
|
token = create_access_token(str(user.id), token_version=1)
|
||||||
|
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req({"access_token": token})))
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
assert "revoked" in str(exc.value.detail).lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_token_returns_user_id():
|
||||||
|
user = _user(token_version=0)
|
||||||
|
token = create_access_token(str(user.id), token_version=0)
|
||||||
|
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||||
|
result = asyncio.run(authenticate(_req({"access_token": token})))
|
||||||
|
assert result == str(user.id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_token_matching_version():
|
||||||
|
user = _user(token_version=5)
|
||||||
|
token = create_access_token(str(user.id), token_version=5)
|
||||||
|
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||||
|
result = asyncio.run(authenticate(_req({"access_token": token})))
|
||||||
|
assert result == str(user.id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── @auth.authenticate edge cases ────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_provider_exception_propagates():
|
||||||
|
"""Provider raises → should not be swallowed silently."""
|
||||||
|
token = create_access_token("user-1")
|
||||||
|
p = AsyncMock()
|
||||||
|
p.get_user = AsyncMock(side_effect=RuntimeError("DB down"))
|
||||||
|
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=p):
|
||||||
|
with pytest.raises(RuntimeError, match="DB down"):
|
||||||
|
asyncio.run(authenticate(_req({"access_token": token})))
|
||||||
|
|
||||||
|
|
||||||
|
def test_jwt_missing_ver_defaults_to_zero():
|
||||||
|
"""JWT without 'ver' claim → decoded as ver=0, matches user with token_version=0."""
|
||||||
|
import jwt as pyjwt
|
||||||
|
|
||||||
|
uid = str(uuid4())
|
||||||
|
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
|
||||||
|
user = _user(user_id=uid, token_version=0)
|
||||||
|
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||||
|
result = asyncio.run(authenticate(_req({"access_token": raw})))
|
||||||
|
assert result == uid
|
||||||
|
|
||||||
|
|
||||||
|
def test_jwt_missing_ver_rejected_when_user_version_nonzero():
|
||||||
|
"""JWT without 'ver' (defaults 0) vs user with token_version=1 → 401."""
|
||||||
|
import jwt as pyjwt
|
||||||
|
|
||||||
|
uid = str(uuid4())
|
||||||
|
raw = pyjwt.encode({"sub": uid, "exp": 9999999999, "iat": 1000000000}, _JWT_SECRET, algorithm="HS256")
|
||||||
|
user = _user(user_id=uid, token_version=1)
|
||||||
|
with patch("app.gateway.langgraph_auth.get_local_provider", return_value=_mock_provider(user)):
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req({"access_token": raw})))
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_wrong_secret_raises_401():
|
||||||
|
"""Token signed with different secret → 401."""
|
||||||
|
import jwt as pyjwt
|
||||||
|
|
||||||
|
raw = pyjwt.encode({"sub": "user-1", "exp": 9999999999, "ver": 0}, "wrong-secret-that-is-long-enough-32chars!", algorithm="HS256")
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req({"access_token": raw})))
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── @auth.on (owner filter) ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeUser:
|
||||||
|
"""Minimal BaseUser-compatible object without langgraph_api.config dependency."""
|
||||||
|
|
||||||
|
def __init__(self, identity: str):
|
||||||
|
self.identity = identity
|
||||||
|
self.is_authenticated = True
|
||||||
|
self.display_name = identity
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ctx(user_id):
|
||||||
|
return Auth.types.AuthContext(resource="threads", action="create", user=_FakeUser(user_id), permissions=[])
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_injects_user_id():
|
||||||
|
value = {}
|
||||||
|
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||||
|
assert value["metadata"]["owner_id"] == "user-a"
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_preserves_existing_metadata():
|
||||||
|
value = {"metadata": {"title": "hello"}}
|
||||||
|
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||||
|
assert value["metadata"]["owner_id"] == "user-a"
|
||||||
|
assert value["metadata"]["title"] == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_returns_user_id_dict():
|
||||||
|
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
||||||
|
assert result == {"owner_id": "user-x"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_read_write_consistency():
|
||||||
|
value = {}
|
||||||
|
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
||||||
|
assert value["metadata"]["owner_id"] == filter_dict["owner_id"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_different_users_different_filters():
|
||||||
|
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
||||||
|
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
||||||
|
assert f_a["owner_id"] != f_b["owner_id"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_overrides_conflicting_user_id():
|
||||||
|
"""If value already has a different user_id in metadata, it gets overwritten."""
|
||||||
|
value = {"metadata": {"owner_id": "attacker"}}
|
||||||
|
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
||||||
|
assert value["metadata"]["owner_id"] == "real-owner"
|
||||||
|
|
||||||
|
|
||||||
|
def test_filter_with_empty_metadata():
|
||||||
|
"""Explicit empty metadata dict is fine."""
|
||||||
|
value = {"metadata": {}}
|
||||||
|
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
||||||
|
assert value["metadata"]["owner_id"] == "user-z"
|
||||||
|
assert result == {"owner_id": "user-z"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Gateway parity ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_shared_jwt_secret():
|
||||||
|
token = create_access_token("user-1", token_version=3)
|
||||||
|
payload = decode_token(token)
|
||||||
|
from app.gateway.auth.errors import TokenError
|
||||||
|
|
||||||
|
assert not isinstance(payload, TokenError)
|
||||||
|
assert payload.sub == "user-1"
|
||||||
|
assert payload.ver == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_langgraph_json_has_auth_path():
|
||||||
|
import json
|
||||||
|
|
||||||
|
config = json.loads((Path(__file__).parent.parent / "langgraph.json").read_text())
|
||||||
|
assert "auth" in config
|
||||||
|
assert "langgraph_auth" in config["auth"]["path"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_handler_has_both_layers():
|
||||||
|
from app.gateway.langgraph_auth import auth
|
||||||
|
|
||||||
|
assert auth._authenticate_handler is not None
|
||||||
|
assert len(auth._global_handlers) == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── CSRF in LangGraph auth ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_get_no_check():
|
||||||
|
"""GET requests skip CSRF — should proceed to JWT validation."""
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req(method="GET")))
|
||||||
|
# Rejected by missing cookie, NOT by CSRF
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
assert "Not authenticated" in str(exc.value.detail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_post_missing_token():
|
||||||
|
"""POST without CSRF token → 403."""
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req(method="POST", cookies={"access_token": "some-jwt"})))
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
assert "CSRF token missing" in str(exc.value.detail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_post_mismatched_token():
|
||||||
|
"""POST with mismatched CSRF tokens → 403."""
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(
|
||||||
|
authenticate(
|
||||||
|
_req(
|
||||||
|
method="POST",
|
||||||
|
cookies={"access_token": "some-jwt", "csrf_token": "real-token"},
|
||||||
|
headers={"x-csrf-token": "wrong-token"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
assert "mismatch" in str(exc.value.detail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_post_matching_token_proceeds_to_jwt():
|
||||||
|
"""POST with matching CSRF tokens passes CSRF check, then fails on JWT."""
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(
|
||||||
|
authenticate(
|
||||||
|
_req(
|
||||||
|
method="POST",
|
||||||
|
cookies={"access_token": "garbage", "csrf_token": "same-token"},
|
||||||
|
headers={"x-csrf-token": "same-token"},
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# Past CSRF, rejected by JWT decode
|
||||||
|
assert exc.value.status_code == 401
|
||||||
|
assert "Token error" in str(exc.value.detail)
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_put_requires_token():
|
||||||
|
"""PUT also requires CSRF."""
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req(method="PUT", cookies={"access_token": "jwt"})))
|
||||||
|
assert exc.value.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_csrf_delete_requires_token():
|
||||||
|
"""DELETE also requires CSRF."""
|
||||||
|
with pytest.raises(Auth.exceptions.HTTPException) as exc:
|
||||||
|
asyncio.run(authenticate(_req(method="DELETE", cookies={"access_token": "jwt"})))
|
||||||
|
assert exc.value.status_code == 403
|
||||||
@@ -52,7 +52,6 @@
|
|||||||
"@xyflow/react": "^12.10.0",
|
"@xyflow/react": "^12.10.0",
|
||||||
"ai": "^6.0.33",
|
"ai": "^6.0.33",
|
||||||
"best-effort-json-parser": "^1.2.1",
|
"best-effort-json-parser": "^1.2.1",
|
||||||
"better-auth": "^1.3",
|
|
||||||
"canvas-confetti": "^1.9.4",
|
"canvas-confetti": "^1.9.4",
|
||||||
"class-variance-authority": "^0.7.1",
|
"class-variance-authority": "^0.7.1",
|
||||||
"clsx": "^2.1.1",
|
"clsx": "^2.1.1",
|
||||||
|
|||||||
Generated
-183
@@ -113,9 +113,6 @@ importers:
|
|||||||
best-effort-json-parser:
|
best-effort-json-parser:
|
||||||
specifier: ^1.2.1
|
specifier: ^1.2.1
|
||||||
version: 1.2.1
|
version: 1.2.1
|
||||||
better-auth:
|
|
||||||
specifier: ^1.3
|
|
||||||
version: 1.4.18(next@16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(vue@3.5.28(typescript@5.9.3))
|
|
||||||
canvas-confetti:
|
canvas-confetti:
|
||||||
specifier: ^1.9.4
|
specifier: ^1.9.4
|
||||||
version: 1.9.4
|
version: 1.9.4
|
||||||
@@ -317,27 +314,6 @@ packages:
|
|||||||
resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==}
|
resolution: {integrity: sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==}
|
||||||
engines: {node: '>=6.9.0'}
|
engines: {node: '>=6.9.0'}
|
||||||
|
|
||||||
'@better-auth/core@1.4.18':
|
|
||||||
resolution: {integrity: sha512-q+awYgC7nkLEBdx2sW0iJjkzgSHlIxGnOpsN1r/O1+a4m7osJNHtfK2mKJSL1I+GfNyIlxJF8WvD/NLuYMpmcg==}
|
|
||||||
peerDependencies:
|
|
||||||
'@better-auth/utils': 0.3.0
|
|
||||||
'@better-fetch/fetch': 1.1.21
|
|
||||||
better-call: 1.1.8
|
|
||||||
jose: ^6.1.0
|
|
||||||
kysely: ^0.28.5
|
|
||||||
nanostores: ^1.0.1
|
|
||||||
|
|
||||||
'@better-auth/telemetry@1.4.18':
|
|
||||||
resolution: {integrity: sha512-e5rDF8S4j3Um/0LIVATL2in9dL4lfO2fr2v1Wio4qTMRbfxqnUDTa+6SZtwdeJrbc4O+a3c+IyIpjG9Q/6GpfQ==}
|
|
||||||
peerDependencies:
|
|
||||||
'@better-auth/core': 1.4.18
|
|
||||||
|
|
||||||
'@better-auth/utils@0.3.0':
|
|
||||||
resolution: {integrity: sha512-W+Adw6ZA6mgvnSnhOki270rwJ42t4XzSK6YWGF//BbVXL6SwCLWfyzBc1lN2m/4RM28KubdBKQ4X5VMoLRNPQw==}
|
|
||||||
|
|
||||||
'@better-fetch/fetch@1.1.21':
|
|
||||||
resolution: {integrity: sha512-/ImESw0sskqlVR94jB+5+Pxjf+xBwDZF/N5+y2/q4EqD7IARUTSpPfIo8uf39SYpCxyOCtbyYpUrZ3F/k0zT4A==}
|
|
||||||
|
|
||||||
'@braintree/sanitize-url@7.1.2':
|
'@braintree/sanitize-url@7.1.2':
|
||||||
resolution: {integrity: sha512-jigsZK+sMF/cuiB7sERuo9V7N9jx+dhmHHnQyDSVdpZwVutaBu7WvNYqMDLSgFgfB30n452TP3vjDAvFC973mA==}
|
resolution: {integrity: sha512-jigsZK+sMF/cuiB7sERuo9V7N9jx+dhmHHnQyDSVdpZwVutaBu7WvNYqMDLSgFgfB30n452TP3vjDAvFC973mA==}
|
||||||
|
|
||||||
@@ -1116,14 +1092,6 @@ packages:
|
|||||||
cpu: [x64]
|
cpu: [x64]
|
||||||
os: [win32]
|
os: [win32]
|
||||||
|
|
||||||
'@noble/ciphers@2.1.1':
|
|
||||||
resolution: {integrity: sha512-bysYuiVfhxNJuldNXlFEitTVdNnYUc+XNJZd7Qm2a5j1vZHgY+fazadNFWFaMK/2vye0JVlxV3gHmC0WDfAOQw==}
|
|
||||||
engines: {node: '>= 20.19.0'}
|
|
||||||
|
|
||||||
'@noble/hashes@2.0.1':
|
|
||||||
resolution: {integrity: sha512-XlOlEbQcE9fmuXxrVTXCTlG2nlRXa9Rj3rr5Ue/+tX+nmkgbX720YHh0VR3hBF9xDvwnb8D2shVGOwNx+ulArw==}
|
|
||||||
engines: {node: '>= 20.19.0'}
|
|
||||||
|
|
||||||
'@nodelib/fs.scandir@2.1.5':
|
'@nodelib/fs.scandir@2.1.5':
|
||||||
resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==}
|
resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==}
|
||||||
engines: {node: '>= 8'}
|
engines: {node: '>= 8'}
|
||||||
@@ -2696,76 +2664,6 @@ packages:
|
|||||||
best-effort-json-parser@1.2.1:
|
best-effort-json-parser@1.2.1:
|
||||||
resolution: {integrity: sha512-UICSLibQdzS1f+PBsi3u2YE3SsdXcWicHUg3IMvfuaePS2AYnZJdJeKhGv5OM8/mqJwPt79aDrEJ1oa84tELvw==}
|
resolution: {integrity: sha512-UICSLibQdzS1f+PBsi3u2YE3SsdXcWicHUg3IMvfuaePS2AYnZJdJeKhGv5OM8/mqJwPt79aDrEJ1oa84tELvw==}
|
||||||
|
|
||||||
better-auth@1.4.18:
|
|
||||||
resolution: {integrity: sha512-bnyifLWBPcYVltH3RhS7CM62MoelEqC6Q+GnZwfiDWNfepXoQZBjEvn4urcERC7NTKgKq5zNBM8rvPvRBa6xcg==}
|
|
||||||
peerDependencies:
|
|
||||||
'@lynx-js/react': '*'
|
|
||||||
'@prisma/client': ^5.0.0 || ^6.0.0 || ^7.0.0
|
|
||||||
'@sveltejs/kit': ^2.0.0
|
|
||||||
'@tanstack/react-start': ^1.0.0
|
|
||||||
'@tanstack/solid-start': ^1.0.0
|
|
||||||
better-sqlite3: ^12.0.0
|
|
||||||
drizzle-kit: '>=0.31.4'
|
|
||||||
drizzle-orm: '>=0.41.0'
|
|
||||||
mongodb: ^6.0.0 || ^7.0.0
|
|
||||||
mysql2: ^3.0.0
|
|
||||||
next: ^14.0.0 || ^15.0.0 || ^16.0.0
|
|
||||||
pg: ^8.0.0
|
|
||||||
prisma: ^5.0.0 || ^6.0.0 || ^7.0.0
|
|
||||||
react: ^18.0.0 || ^19.0.0
|
|
||||||
react-dom: ^18.0.0 || ^19.0.0
|
|
||||||
solid-js: ^1.0.0
|
|
||||||
svelte: ^4.0.0 || ^5.0.0
|
|
||||||
vitest: ^2.0.0 || ^3.0.0 || ^4.0.0
|
|
||||||
vue: ^3.0.0
|
|
||||||
peerDependenciesMeta:
|
|
||||||
'@lynx-js/react':
|
|
||||||
optional: true
|
|
||||||
'@prisma/client':
|
|
||||||
optional: true
|
|
||||||
'@sveltejs/kit':
|
|
||||||
optional: true
|
|
||||||
'@tanstack/react-start':
|
|
||||||
optional: true
|
|
||||||
'@tanstack/solid-start':
|
|
||||||
optional: true
|
|
||||||
better-sqlite3:
|
|
||||||
optional: true
|
|
||||||
drizzle-kit:
|
|
||||||
optional: true
|
|
||||||
drizzle-orm:
|
|
||||||
optional: true
|
|
||||||
mongodb:
|
|
||||||
optional: true
|
|
||||||
mysql2:
|
|
||||||
optional: true
|
|
||||||
next:
|
|
||||||
optional: true
|
|
||||||
pg:
|
|
||||||
optional: true
|
|
||||||
prisma:
|
|
||||||
optional: true
|
|
||||||
react:
|
|
||||||
optional: true
|
|
||||||
react-dom:
|
|
||||||
optional: true
|
|
||||||
solid-js:
|
|
||||||
optional: true
|
|
||||||
svelte:
|
|
||||||
optional: true
|
|
||||||
vitest:
|
|
||||||
optional: true
|
|
||||||
vue:
|
|
||||||
optional: true
|
|
||||||
|
|
||||||
better-call@1.1.8:
|
|
||||||
resolution: {integrity: sha512-XMQ2rs6FNXasGNfMjzbyroSwKwYbZ/T3IxruSS6U2MJRsSYh3wYtG3o6H00ZlKZ/C/UPOAD97tqgQJNsxyeTXw==}
|
|
||||||
peerDependencies:
|
|
||||||
zod: ^4.0.0
|
|
||||||
peerDependenciesMeta:
|
|
||||||
zod:
|
|
||||||
optional: true
|
|
||||||
|
|
||||||
better-react-mathjax@2.3.0:
|
better-react-mathjax@2.3.0:
|
||||||
resolution: {integrity: sha512-K0ceQC+jQmB+NLDogO5HCpqmYf18AU2FxDbLdduYgkHYWZApFggkHE4dIaXCV1NqeoscESYXXo1GSkY6fA295w==}
|
resolution: {integrity: sha512-K0ceQC+jQmB+NLDogO5HCpqmYf18AU2FxDbLdduYgkHYWZApFggkHE4dIaXCV1NqeoscESYXXo1GSkY6fA295w==}
|
||||||
peerDependencies:
|
peerDependencies:
|
||||||
@@ -3973,9 +3871,6 @@ packages:
|
|||||||
resolution: {integrity: sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==}
|
resolution: {integrity: sha512-ekilCSN1jwRvIbgeg/57YFh8qQDNbwDb9xT/qu2DAHbFFZUicIl4ygVaAvzveMhMVr3LnpSKTNnwt8PoOfmKhQ==}
|
||||||
hasBin: true
|
hasBin: true
|
||||||
|
|
||||||
jose@6.1.3:
|
|
||||||
resolution: {integrity: sha512-0TpaTfihd4QMNwrz/ob2Bp7X04yuxJkjRGi4aKmOqwhov54i6u79oCv7T+C7lo70MKH6BesI3vscD1yb/yzKXQ==}
|
|
||||||
|
|
||||||
js-tiktoken@1.0.21:
|
js-tiktoken@1.0.21:
|
||||||
resolution: {integrity: sha512-biOj/6M5qdgx5TKjDnFT1ymSpM5tbd3ylwDtrQvFQSu0Z7bBYko2dF+W/aUkXUPuk6IVpRxk/3Q2sHOzGlS36g==}
|
resolution: {integrity: sha512-biOj/6M5qdgx5TKjDnFT1ymSpM5tbd3ylwDtrQvFQSu0Z7bBYko2dF+W/aUkXUPuk6IVpRxk/3Q2sHOzGlS36g==}
|
||||||
|
|
||||||
@@ -4026,10 +3921,6 @@ packages:
|
|||||||
knitwork@1.3.0:
|
knitwork@1.3.0:
|
||||||
resolution: {integrity: sha512-4LqMNoONzR43B1W0ek0fhXMsDNW/zxa1NdFAVMY+k28pgZLovR4G3PB5MrpTxCy1QaZCqNoiaKPr5w5qZHfSNw==}
|
resolution: {integrity: sha512-4LqMNoONzR43B1W0ek0fhXMsDNW/zxa1NdFAVMY+k28pgZLovR4G3PB5MrpTxCy1QaZCqNoiaKPr5w5qZHfSNw==}
|
||||||
|
|
||||||
kysely@0.28.11:
|
|
||||||
resolution: {integrity: sha512-zpGIFg0HuoC893rIjYX1BETkVWdDnzTzF5e0kWXJFg5lE0k1/LfNWBejrcnOFu8Q2Rfq/hTDTU7XLUM8QOrpzg==}
|
|
||||||
engines: {node: '>=20.0.0'}
|
|
||||||
|
|
||||||
langium@3.3.1:
|
langium@3.3.1:
|
||||||
resolution: {integrity: sha512-QJv/h939gDpvT+9SiLVlY7tZC3xB2qK57v0J04Sh9wpMb6MP1q8gB21L3WIo8T5P1MSMg3Ep14L7KkDCFG3y4w==}
|
resolution: {integrity: sha512-QJv/h939gDpvT+9SiLVlY7tZC3xB2qK57v0J04Sh9wpMb6MP1q8gB21L3WIo8T5P1MSMg3Ep14L7KkDCFG3y4w==}
|
||||||
engines: {node: '>=16.0.0'}
|
engines: {node: '>=16.0.0'}
|
||||||
@@ -4458,10 +4349,6 @@ packages:
|
|||||||
engines: {node: ^18 || >=20}
|
engines: {node: ^18 || >=20}
|
||||||
hasBin: true
|
hasBin: true
|
||||||
|
|
||||||
nanostores@1.1.0:
|
|
||||||
resolution: {integrity: sha512-yJBmDJr18xy47dbNVlHcgdPrulSn1nhSE6Ns9vTG+Nx9VPT6iV1MD6aQFp/t52zpf82FhLLTXAXr30NuCnxvwA==}
|
|
||||||
engines: {node: ^20.0.0 || >=22.0.0}
|
|
||||||
|
|
||||||
napi-postinstall@0.3.4:
|
napi-postinstall@0.3.4:
|
||||||
resolution: {integrity: sha512-PHI5f1O0EP5xJ9gQmFGMS6IZcrVvTjpXjz7Na41gTE7eE2hK11lg04CECCYEEjdc17EV4DO+fkGEtt7TpTaTiQ==}
|
resolution: {integrity: sha512-PHI5f1O0EP5xJ9gQmFGMS6IZcrVvTjpXjz7Na41gTE7eE2hK11lg04CECCYEEjdc17EV4DO+fkGEtt7TpTaTiQ==}
|
||||||
engines: {node: ^12.20.0 || ^14.18.0 || >=16.0.0}
|
engines: {node: ^12.20.0 || ^14.18.0 || >=16.0.0}
|
||||||
@@ -5050,9 +4937,6 @@ packages:
|
|||||||
engines: {node: '>=18.0.0', npm: '>=8.0.0'}
|
engines: {node: '>=18.0.0', npm: '>=8.0.0'}
|
||||||
hasBin: true
|
hasBin: true
|
||||||
|
|
||||||
rou3@0.7.12:
|
|
||||||
resolution: {integrity: sha512-iFE4hLDuloSWcD7mjdCDhx2bKcIsYbtOTpfH5MHHLSKMOUyjqQXTeZVa289uuwEGEKFoE/BAPbhaU4B774nceg==}
|
|
||||||
|
|
||||||
roughjs@4.6.6:
|
roughjs@4.6.6:
|
||||||
resolution: {integrity: sha512-ZUz/69+SYpFN/g/lUlo2FXcIjRkSu3nDarreVdGGndHEBJ6cXPdKguS8JGxwj5HA5xIbVKSmLgr5b3AWxtRfvQ==}
|
resolution: {integrity: sha512-ZUz/69+SYpFN/g/lUlo2FXcIjRkSu3nDarreVdGGndHEBJ6cXPdKguS8JGxwj5HA5xIbVKSmLgr5b3AWxtRfvQ==}
|
||||||
|
|
||||||
@@ -5105,9 +4989,6 @@ packages:
|
|||||||
server-only@0.0.1:
|
server-only@0.0.1:
|
||||||
resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==}
|
resolution: {integrity: sha512-qepMx2JxAa5jjfzxG79yPPq+8BuFToHd1hm7kI+Z4zAq1ftQiP7HcxMhDDItrbtwVeLg/cY2JnKnrcFkmiswNA==}
|
||||||
|
|
||||||
set-cookie-parser@2.7.2:
|
|
||||||
resolution: {integrity: sha512-oeM1lpU/UvhTxw+g3cIfxXHyJRc/uidd3yK1P242gzHds0udQBYzs3y8j4gCCW+ZJ7ad0yctld8RYO+bdurlvw==}
|
|
||||||
|
|
||||||
set-function-length@1.2.2:
|
set-function-length@1.2.2:
|
||||||
resolution: {integrity: sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==}
|
resolution: {integrity: sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==}
|
||||||
engines: {node: '>= 0.4'}
|
engines: {node: '>= 0.4'}
|
||||||
@@ -5802,27 +5683,6 @@ snapshots:
|
|||||||
'@babel/helper-string-parser': 7.27.1
|
'@babel/helper-string-parser': 7.27.1
|
||||||
'@babel/helper-validator-identifier': 7.28.5
|
'@babel/helper-validator-identifier': 7.28.5
|
||||||
|
|
||||||
'@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)':
|
|
||||||
dependencies:
|
|
||||||
'@better-auth/utils': 0.3.0
|
|
||||||
'@better-fetch/fetch': 1.1.21
|
|
||||||
'@standard-schema/spec': 1.1.0
|
|
||||||
better-call: 1.1.8(zod@4.3.6)
|
|
||||||
jose: 6.1.3
|
|
||||||
kysely: 0.28.11
|
|
||||||
nanostores: 1.1.0
|
|
||||||
zod: 4.3.6
|
|
||||||
|
|
||||||
'@better-auth/telemetry@1.4.18(@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0))':
|
|
||||||
dependencies:
|
|
||||||
'@better-auth/core': 1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)
|
|
||||||
'@better-auth/utils': 0.3.0
|
|
||||||
'@better-fetch/fetch': 1.1.21
|
|
||||||
|
|
||||||
'@better-auth/utils@0.3.0': {}
|
|
||||||
|
|
||||||
'@better-fetch/fetch@1.1.21': {}
|
|
||||||
|
|
||||||
'@braintree/sanitize-url@7.1.2': {}
|
'@braintree/sanitize-url@7.1.2': {}
|
||||||
|
|
||||||
'@cfworker/json-schema@4.1.1': {}
|
'@cfworker/json-schema@4.1.1': {}
|
||||||
@@ -6671,10 +6531,6 @@ snapshots:
|
|||||||
'@next/swc-win32-x64-msvc@16.1.7':
|
'@next/swc-win32-x64-msvc@16.1.7':
|
||||||
optional: true
|
optional: true
|
||||||
|
|
||||||
'@noble/ciphers@2.1.1': {}
|
|
||||||
|
|
||||||
'@noble/hashes@2.0.1': {}
|
|
||||||
|
|
||||||
'@nodelib/fs.scandir@2.1.5':
|
'@nodelib/fs.scandir@2.1.5':
|
||||||
dependencies:
|
dependencies:
|
||||||
'@nodelib/fs.stat': 2.0.5
|
'@nodelib/fs.stat': 2.0.5
|
||||||
@@ -8242,35 +8098,6 @@ snapshots:
|
|||||||
|
|
||||||
best-effort-json-parser@1.2.1: {}
|
best-effort-json-parser@1.2.1: {}
|
||||||
|
|
||||||
better-auth@1.4.18(next@16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(vue@3.5.28(typescript@5.9.3)):
|
|
||||||
dependencies:
|
|
||||||
'@better-auth/core': 1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0)
|
|
||||||
'@better-auth/telemetry': 1.4.18(@better-auth/core@1.4.18(@better-auth/utils@0.3.0)(@better-fetch/fetch@1.1.21)(better-call@1.1.8(zod@3.25.76))(jose@6.1.3)(kysely@0.28.11)(nanostores@1.1.0))
|
|
||||||
'@better-auth/utils': 0.3.0
|
|
||||||
'@better-fetch/fetch': 1.1.21
|
|
||||||
'@noble/ciphers': 2.1.1
|
|
||||||
'@noble/hashes': 2.0.1
|
|
||||||
better-call: 1.1.8(zod@4.3.6)
|
|
||||||
defu: 6.1.4
|
|
||||||
jose: 6.1.3
|
|
||||||
kysely: 0.28.11
|
|
||||||
nanostores: 1.1.0
|
|
||||||
zod: 4.3.6
|
|
||||||
optionalDependencies:
|
|
||||||
next: 16.1.7(@opentelemetry/api@1.9.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)
|
|
||||||
react: 19.2.4
|
|
||||||
react-dom: 19.2.4(react@19.2.4)
|
|
||||||
vue: 3.5.28(typescript@5.9.3)
|
|
||||||
|
|
||||||
better-call@1.1.8(zod@4.3.6):
|
|
||||||
dependencies:
|
|
||||||
'@better-auth/utils': 0.3.0
|
|
||||||
'@better-fetch/fetch': 1.1.21
|
|
||||||
rou3: 0.7.12
|
|
||||||
set-cookie-parser: 2.7.2
|
|
||||||
optionalDependencies:
|
|
||||||
zod: 4.3.6
|
|
||||||
|
|
||||||
better-react-mathjax@2.3.0(react@19.2.4):
|
better-react-mathjax@2.3.0(react@19.2.4):
|
||||||
dependencies:
|
dependencies:
|
||||||
mathjax-full: 3.2.2
|
mathjax-full: 3.2.2
|
||||||
@@ -9786,8 +9613,6 @@ snapshots:
|
|||||||
|
|
||||||
jiti@2.6.1: {}
|
jiti@2.6.1: {}
|
||||||
|
|
||||||
jose@6.1.3: {}
|
|
||||||
|
|
||||||
js-tiktoken@1.0.21:
|
js-tiktoken@1.0.21:
|
||||||
dependencies:
|
dependencies:
|
||||||
base64-js: 1.5.1
|
base64-js: 1.5.1
|
||||||
@@ -9833,8 +9658,6 @@ snapshots:
|
|||||||
|
|
||||||
knitwork@1.3.0: {}
|
knitwork@1.3.0: {}
|
||||||
|
|
||||||
kysely@0.28.11: {}
|
|
||||||
|
|
||||||
langium@3.3.1:
|
langium@3.3.1:
|
||||||
dependencies:
|
dependencies:
|
||||||
chevrotain: 11.0.3
|
chevrotain: 11.0.3
|
||||||
@@ -10529,8 +10352,6 @@ snapshots:
|
|||||||
|
|
||||||
nanoid@5.1.6: {}
|
nanoid@5.1.6: {}
|
||||||
|
|
||||||
nanostores@1.1.0: {}
|
|
||||||
|
|
||||||
napi-postinstall@0.3.4: {}
|
napi-postinstall@0.3.4: {}
|
||||||
|
|
||||||
natural-compare@1.4.0: {}
|
natural-compare@1.4.0: {}
|
||||||
@@ -11305,8 +11126,6 @@ snapshots:
|
|||||||
'@rollup/rollup-win32-x64-msvc': 4.60.0
|
'@rollup/rollup-win32-x64-msvc': 4.60.0
|
||||||
fsevents: 2.3.3
|
fsevents: 2.3.3
|
||||||
|
|
||||||
rou3@0.7.12: {}
|
|
||||||
|
|
||||||
roughjs@4.6.6:
|
roughjs@4.6.6:
|
||||||
dependencies:
|
dependencies:
|
||||||
hachure-fill: 0.5.2
|
hachure-fill: 0.5.2
|
||||||
@@ -11373,8 +11192,6 @@ snapshots:
|
|||||||
|
|
||||||
server-only@0.0.1: {}
|
server-only@0.0.1: {}
|
||||||
|
|
||||||
set-cookie-parser@2.7.2: {}
|
|
||||||
|
|
||||||
set-function-length@1.2.2:
|
set-function-length@1.2.2:
|
||||||
dependencies:
|
dependencies:
|
||||||
define-data-property: 1.1.4
|
define-data-property: 1.1.4
|
||||||
|
|||||||
@@ -0,0 +1,45 @@
|
|||||||
|
import Link from "next/link";
|
||||||
|
import { redirect } from "next/navigation";
|
||||||
|
import { type ReactNode } from "react";
|
||||||
|
|
||||||
|
import { AuthProvider } from "@/core/auth/AuthProvider";
|
||||||
|
import { getServerSideUser } from "@/core/auth/server";
|
||||||
|
import { assertNever } from "@/core/auth/types";
|
||||||
|
|
||||||
|
export const dynamic = "force-dynamic";
|
||||||
|
|
||||||
|
export default async function AuthLayout({
|
||||||
|
children,
|
||||||
|
}: {
|
||||||
|
children: ReactNode;
|
||||||
|
}) {
|
||||||
|
const result = await getServerSideUser();
|
||||||
|
|
||||||
|
switch (result.tag) {
|
||||||
|
case "authenticated":
|
||||||
|
redirect("/workspace");
|
||||||
|
case "needs_setup":
|
||||||
|
// Allow access to setup page
|
||||||
|
return <AuthProvider initialUser={result.user}>{children}</AuthProvider>;
|
||||||
|
case "unauthenticated":
|
||||||
|
return <AuthProvider initialUser={null}>{children}</AuthProvider>;
|
||||||
|
case "gateway_unavailable":
|
||||||
|
return (
|
||||||
|
<div className="flex h-screen flex-col items-center justify-center gap-4">
|
||||||
|
<p className="text-muted-foreground">
|
||||||
|
Service temporarily unavailable.
|
||||||
|
</p>
|
||||||
|
<Link
|
||||||
|
href="/login"
|
||||||
|
className="bg-primary text-primary-foreground hover:bg-primary/90 rounded-md px-4 py-2 text-sm"
|
||||||
|
>
|
||||||
|
Retry
|
||||||
|
</Link>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
case "config_error":
|
||||||
|
throw new Error(result.message);
|
||||||
|
default:
|
||||||
|
assertNever(result);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,183 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import Link from "next/link";
|
||||||
|
import { useRouter, useSearchParams } from "next/navigation";
|
||||||
|
import { useEffect, useState } from "react";
|
||||||
|
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { useAuth } from "@/core/auth/AuthProvider";
|
||||||
|
import { parseAuthError } from "@/core/auth/types";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validate next parameter
|
||||||
|
* Prevent open redirect attacks
|
||||||
|
* Per RFC-001: Only allow relative paths starting with /
|
||||||
|
*/
|
||||||
|
function validateNextParam(next: string | null): string | null {
|
||||||
|
if (!next) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Need start with / (relative path)
|
||||||
|
if (!next.startsWith("/")) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disallow protocol-relative URLs
|
||||||
|
if (
|
||||||
|
next.startsWith("//") ||
|
||||||
|
next.startsWith("http://") ||
|
||||||
|
next.startsWith("https://")
|
||||||
|
) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disallow URLs with different protocols (e.g., javascript:, data:, etc)
|
||||||
|
if (next.includes(":") && !next.startsWith("/")) {
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Valid relative path
|
||||||
|
return next;
|
||||||
|
}
|
||||||
|
|
||||||
|
export default function LoginPage() {
|
||||||
|
const router = useRouter();
|
||||||
|
const searchParams = useSearchParams();
|
||||||
|
const { isAuthenticated } = useAuth();
|
||||||
|
|
||||||
|
const [email, setEmail] = useState("");
|
||||||
|
const [password, setPassword] = useState("");
|
||||||
|
const [isLogin, setIsLogin] = useState(true);
|
||||||
|
const [error, setError] = useState("");
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
|
||||||
|
// Get next parameter for validated redirect
|
||||||
|
const nextParam = searchParams.get("next");
|
||||||
|
const redirectPath = validateNextParam(nextParam) ?? "/workspace";
|
||||||
|
|
||||||
|
// Redirect if already authenticated (client-side, post-login)
|
||||||
|
useEffect(() => {
|
||||||
|
if (isAuthenticated) {
|
||||||
|
router.push(redirectPath);
|
||||||
|
}
|
||||||
|
}, [isAuthenticated, redirectPath, router]);
|
||||||
|
|
||||||
|
const handleSubmit = async (e: React.FormEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
setError("");
|
||||||
|
setLoading(true);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const endpoint = isLogin
|
||||||
|
? "/api/v1/auth/login/local"
|
||||||
|
: "/api/v1/auth/register";
|
||||||
|
const body = isLogin
|
||||||
|
? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}`
|
||||||
|
: JSON.stringify({ email, password });
|
||||||
|
|
||||||
|
const headers: HeadersInit = isLogin
|
||||||
|
? { "Content-Type": "application/x-www-form-urlencoded" }
|
||||||
|
: { "Content-Type": "application/json" };
|
||||||
|
|
||||||
|
const res = await fetch(endpoint, {
|
||||||
|
method: "POST",
|
||||||
|
headers,
|
||||||
|
body,
|
||||||
|
credentials: "include", // Important: include HttpOnly cookie
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!res.ok) {
|
||||||
|
const data = await res.json();
|
||||||
|
const authError = parseAuthError(data);
|
||||||
|
setError(authError.message);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Both login and register set a cookie — redirect to workspace
|
||||||
|
router.push(redirectPath);
|
||||||
|
} catch (_err) {
|
||||||
|
setError("Network error. Please try again.");
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex min-h-screen items-center justify-center bg-[#0a0a0a]">
|
||||||
|
<div className="border-border/20 w-full max-w-md space-y-6 rounded-lg border bg-black/50 p-8 backdrop-blur-sm">
|
||||||
|
<div className="text-center">
|
||||||
|
<h1 className="font-serif text-3xl">DeerFlow</h1>
|
||||||
|
<p className="text-muted-foreground mt-2">
|
||||||
|
{isLogin ? "Sign in to your account" : "Create a new account"}
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<form onSubmit={handleSubmit} className="space-y-4">
|
||||||
|
<div>
|
||||||
|
<label htmlFor="email" className="text-sm font-medium">
|
||||||
|
Email
|
||||||
|
</label>
|
||||||
|
<Input
|
||||||
|
id="email"
|
||||||
|
type="email"
|
||||||
|
value={email}
|
||||||
|
onChange={(e) => setEmail(e.target.value)}
|
||||||
|
placeholder="you@example.com"
|
||||||
|
required
|
||||||
|
className="mt-1 bg-white text-black"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div>
|
||||||
|
<label htmlFor="password" className="text-sm font-medium">
|
||||||
|
Password
|
||||||
|
</label>
|
||||||
|
<Input
|
||||||
|
id="password"
|
||||||
|
type="password"
|
||||||
|
value={password}
|
||||||
|
onChange={(e) => setPassword(e.target.value)}
|
||||||
|
placeholder="•••••••"
|
||||||
|
required
|
||||||
|
minLength={isLogin ? 6 : 8}
|
||||||
|
className="mt-1 bg-white text-black"
|
||||||
|
/>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
{error && <p className="text-sm text-red-500">{error}</p>}
|
||||||
|
|
||||||
|
<Button type="submit" className="w-full" disabled={loading}>
|
||||||
|
{loading
|
||||||
|
? "Please wait..."
|
||||||
|
: isLogin
|
||||||
|
? "Sign In"
|
||||||
|
: "Create Account"}
|
||||||
|
</Button>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<div className="text-center text-sm">
|
||||||
|
<button
|
||||||
|
type="button"
|
||||||
|
onClick={() => {
|
||||||
|
setIsLogin(!isLogin);
|
||||||
|
setError("");
|
||||||
|
}}
|
||||||
|
className="text-blue-500 hover:underline"
|
||||||
|
>
|
||||||
|
{isLogin
|
||||||
|
? "Don't have an account? Sign up"
|
||||||
|
: "Already have an account? Sign in"}
|
||||||
|
</button>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div className="text-muted-foreground text-center text-xs">
|
||||||
|
<Link href="/" className="hover:underline">
|
||||||
|
← Back to home
|
||||||
|
</Link>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useRouter } from "next/navigation";
|
||||||
|
import { useState } from "react";
|
||||||
|
|
||||||
|
import { Button } from "@/components/ui/button";
|
||||||
|
import { Input } from "@/components/ui/input";
|
||||||
|
import { getCsrfHeaders } from "@/core/api/fetcher";
|
||||||
|
import { parseAuthError } from "@/core/auth/types";
|
||||||
|
|
||||||
|
export default function SetupPage() {
|
||||||
|
const router = useRouter();
|
||||||
|
const [email, setEmail] = useState("");
|
||||||
|
const [newPassword, setNewPassword] = useState("");
|
||||||
|
const [confirmPassword, setConfirmPassword] = useState("");
|
||||||
|
const [currentPassword, setCurrentPassword] = useState("");
|
||||||
|
const [error, setError] = useState("");
|
||||||
|
const [loading, setLoading] = useState(false);
|
||||||
|
|
||||||
|
const handleSetup = async (e: React.FormEvent) => {
|
||||||
|
e.preventDefault();
|
||||||
|
setError("");
|
||||||
|
|
||||||
|
if (newPassword !== confirmPassword) {
|
||||||
|
setError("Passwords do not match");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
if (newPassword.length < 8) {
|
||||||
|
setError("Password must be at least 8 characters");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
setLoading(true);
|
||||||
|
try {
|
||||||
|
const res = await fetch("/api/v1/auth/change-password", {
|
||||||
|
method: "POST",
|
||||||
|
headers: {
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
...getCsrfHeaders(),
|
||||||
|
},
|
||||||
|
credentials: "include",
|
||||||
|
body: JSON.stringify({
|
||||||
|
current_password: currentPassword,
|
||||||
|
new_password: newPassword,
|
||||||
|
new_email: email || undefined,
|
||||||
|
}),
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!res.ok) {
|
||||||
|
const data = await res.json();
|
||||||
|
const authError = parseAuthError(data);
|
||||||
|
setError(authError.message);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
router.push("/workspace");
|
||||||
|
} catch {
|
||||||
|
setError("Network error. Please try again.");
|
||||||
|
} finally {
|
||||||
|
setLoading(false);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
return (
|
||||||
|
<div className="flex min-h-screen items-center justify-center">
|
||||||
|
<div className="w-full max-w-sm space-y-6 p-6">
|
||||||
|
<div className="text-center">
|
||||||
|
<h1 className="font-serif text-3xl">DeerFlow</h1>
|
||||||
|
<p className="text-muted-foreground mt-2">
|
||||||
|
Complete admin account setup
|
||||||
|
</p>
|
||||||
|
<p className="text-muted-foreground mt-1 text-xs">
|
||||||
|
Set your real email and a new password.
|
||||||
|
</p>
|
||||||
|
</div>
|
||||||
|
<form onSubmit={handleSetup} className="space-y-4">
|
||||||
|
<Input
|
||||||
|
type="email"
|
||||||
|
placeholder="Your email"
|
||||||
|
value={email}
|
||||||
|
onChange={(e) => setEmail(e.target.value)}
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
<Input
|
||||||
|
type="password"
|
||||||
|
placeholder="Current password (from console log)"
|
||||||
|
value={currentPassword}
|
||||||
|
onChange={(e) => setCurrentPassword(e.target.value)}
|
||||||
|
required
|
||||||
|
/>
|
||||||
|
<Input
|
||||||
|
type="password"
|
||||||
|
placeholder="New password"
|
||||||
|
value={newPassword}
|
||||||
|
onChange={(e) => setNewPassword(e.target.value)}
|
||||||
|
required
|
||||||
|
minLength={8}
|
||||||
|
/>
|
||||||
|
<Input
|
||||||
|
type="password"
|
||||||
|
placeholder="Confirm new password"
|
||||||
|
value={confirmPassword}
|
||||||
|
onChange={(e) => setConfirmPassword(e.target.value)}
|
||||||
|
required
|
||||||
|
minLength={8}
|
||||||
|
/>
|
||||||
|
{error && <p className="text-sm text-red-500">{error}</p>}
|
||||||
|
<Button type="submit" className="w-full" disabled={loading}>
|
||||||
|
{loading ? "Setting up..." : "Complete Setup"}
|
||||||
|
</Button>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
import { toNextJsHandler } from "better-auth/next-js";
|
|
||||||
|
|
||||||
import { auth } from "@/server/better-auth";
|
|
||||||
|
|
||||||
export const { GET, POST } = toNextJsHandler(auth.handler);
|
|
||||||
@@ -1,47 +1,58 @@
|
|||||||
"use client";
|
import Link from "next/link";
|
||||||
|
import { redirect } from "next/navigation";
|
||||||
|
|
||||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
import { AuthProvider } from "@/core/auth/AuthProvider";
|
||||||
import { useCallback, useEffect, useLayoutEffect, useState } from "react";
|
import { getServerSideUser } from "@/core/auth/server";
|
||||||
import { Toaster } from "sonner";
|
import { assertNever } from "@/core/auth/types";
|
||||||
|
|
||||||
import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar";
|
import { WorkspaceContent } from "./workspace-content";
|
||||||
import { CommandPalette } from "@/components/workspace/command-palette";
|
|
||||||
import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar";
|
|
||||||
import { getLocalSettings, useLocalSettings } from "@/core/settings";
|
|
||||||
|
|
||||||
const queryClient = new QueryClient();
|
export const dynamic = "force-dynamic";
|
||||||
|
|
||||||
export default function WorkspaceLayout({
|
export default async function WorkspaceLayout({
|
||||||
children,
|
children,
|
||||||
}: Readonly<{ children: React.ReactNode }>) {
|
}: Readonly<{ children: React.ReactNode }>) {
|
||||||
const [settings, setSettings] = useLocalSettings();
|
const result = await getServerSideUser();
|
||||||
const [open, setOpen] = useState(false); // SSR default: open (matches server render)
|
|
||||||
useLayoutEffect(() => {
|
switch (result.tag) {
|
||||||
// Runs synchronously before first paint on the client — no visual flash
|
case "authenticated":
|
||||||
setOpen(!getLocalSettings().layout.sidebar_collapsed);
|
return (
|
||||||
}, []);
|
<AuthProvider initialUser={result.user}>
|
||||||
useEffect(() => {
|
<WorkspaceContent>{children}</WorkspaceContent>
|
||||||
setOpen(!settings.layout.sidebar_collapsed);
|
</AuthProvider>
|
||||||
}, [settings.layout.sidebar_collapsed]);
|
);
|
||||||
const handleOpenChange = useCallback(
|
case "needs_setup":
|
||||||
(open: boolean) => {
|
redirect("/setup");
|
||||||
setOpen(open);
|
case "unauthenticated":
|
||||||
setSettings("layout", { sidebar_collapsed: !open });
|
redirect("/login");
|
||||||
},
|
case "gateway_unavailable":
|
||||||
[setSettings],
|
return (
|
||||||
);
|
<div className="flex h-screen flex-col items-center justify-center gap-4">
|
||||||
return (
|
<p className="text-muted-foreground">
|
||||||
<QueryClientProvider client={queryClient}>
|
Service temporarily unavailable.
|
||||||
<SidebarProvider
|
</p>
|
||||||
className="h-screen"
|
<p className="text-muted-foreground text-xs">
|
||||||
open={open}
|
The backend may be restarting. Please wait a moment and try again.
|
||||||
onOpenChange={handleOpenChange}
|
</p>
|
||||||
>
|
<div className="flex gap-3">
|
||||||
<WorkspaceSidebar />
|
<Link
|
||||||
<SidebarInset className="min-w-0">{children}</SidebarInset>
|
href="/workspace"
|
||||||
</SidebarProvider>
|
className="bg-primary text-primary-foreground hover:bg-primary/90 rounded-md px-4 py-2 text-sm"
|
||||||
<CommandPalette />
|
>
|
||||||
<Toaster position="top-center" />
|
Retry
|
||||||
</QueryClientProvider>
|
</Link>
|
||||||
);
|
<Link
|
||||||
|
href="/api/v1/auth/logout"
|
||||||
|
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
|
||||||
|
>
|
||||||
|
Logout & Reset
|
||||||
|
</Link>
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
);
|
||||||
|
case "config_error":
|
||||||
|
throw new Error(result.message);
|
||||||
|
default:
|
||||||
|
assertNever(result);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||||
|
import { useCallback, useEffect, useLayoutEffect, useState } from "react";
|
||||||
|
import { Toaster } from "sonner";
|
||||||
|
|
||||||
|
import { SidebarInset, SidebarProvider } from "@/components/ui/sidebar";
|
||||||
|
import { CommandPalette } from "@/components/workspace/command-palette";
|
||||||
|
import { WorkspaceSidebar } from "@/components/workspace/workspace-sidebar";
|
||||||
|
import { getLocalSettings, useLocalSettings } from "@/core/settings";
|
||||||
|
|
||||||
|
export function WorkspaceContent({
|
||||||
|
children,
|
||||||
|
}: Readonly<{ children: React.ReactNode }>) {
|
||||||
|
const [queryClient] = useState(() => new QueryClient());
|
||||||
|
const [settings, setSettings] = useLocalSettings();
|
||||||
|
const [open, setOpen] = useState(false); // SSR default: open (matches server render)
|
||||||
|
|
||||||
|
useLayoutEffect(() => {
|
||||||
|
// Runs synchronously before first paint on the client — no visual flash
|
||||||
|
setOpen(!getLocalSettings().layout.sidebar_collapsed);
|
||||||
|
}, []);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
setOpen(!settings.layout.sidebar_collapsed);
|
||||||
|
}, [settings.layout.sidebar_collapsed]);
|
||||||
|
|
||||||
|
const handleOpenChange = useCallback(
|
||||||
|
(open: boolean) => {
|
||||||
|
setOpen(open);
|
||||||
|
setSettings("layout", { sidebar_collapsed: !open });
|
||||||
|
},
|
||||||
|
[setSettings],
|
||||||
|
);
|
||||||
|
|
||||||
|
return (
|
||||||
|
<QueryClientProvider client={queryClient}>
|
||||||
|
<SidebarProvider
|
||||||
|
className="h-screen"
|
||||||
|
open={open}
|
||||||
|
onOpenChange={handleOpenChange}
|
||||||
|
>
|
||||||
|
<WorkspaceSidebar />
|
||||||
|
<SidebarInset className="min-w-0">{children}</SidebarInset>
|
||||||
|
</SidebarProvider>
|
||||||
|
<CommandPalette />
|
||||||
|
<Toaster position="top-center" />
|
||||||
|
</QueryClientProvider>
|
||||||
|
);
|
||||||
|
}
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
import { buildLoginUrl } from "@/core/auth/types";
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetch with credentials. Automatically redirects to login on 401.
|
||||||
|
*/
|
||||||
|
export async function fetchWithAuth(
|
||||||
|
input: RequestInfo | string,
|
||||||
|
init?: RequestInit,
|
||||||
|
): Promise<Response> {
|
||||||
|
const url = typeof input === "string" ? input : input.url;
|
||||||
|
const res = await fetch(url, {
|
||||||
|
...init,
|
||||||
|
credentials: "include",
|
||||||
|
});
|
||||||
|
|
||||||
|
if (res.status === 401) {
|
||||||
|
window.location.href = buildLoginUrl(window.location.pathname);
|
||||||
|
throw new Error("Unauthorized");
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Build headers for CSRF-protected requests
|
||||||
|
* Per RFC-001: Double Submit Cookie pattern
|
||||||
|
*/
|
||||||
|
export function getCsrfHeaders(): HeadersInit {
|
||||||
|
const token = getCsrfToken();
|
||||||
|
return token ? { "X-CSRF-Token": token } : {};
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get CSRF token from cookie
|
||||||
|
*/
|
||||||
|
function getCsrfToken(): string | null {
|
||||||
|
const match = /csrf_token=([^;]+)/.exec(document.cookie);
|
||||||
|
return match?.[1] ?? null;
|
||||||
|
}
|
||||||
@@ -0,0 +1,165 @@
|
|||||||
|
"use client";
|
||||||
|
|
||||||
|
import { useRouter, usePathname } from "next/navigation";
|
||||||
|
import React, {
|
||||||
|
createContext,
|
||||||
|
useContext,
|
||||||
|
useState,
|
||||||
|
useCallback,
|
||||||
|
useEffect,
|
||||||
|
type ReactNode,
|
||||||
|
} from "react";
|
||||||
|
|
||||||
|
import { type User, buildLoginUrl } from "./types";
|
||||||
|
|
||||||
|
// Re-export for consumers
|
||||||
|
export type { User };
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Authentication context provided to consuming components
|
||||||
|
*/
|
||||||
|
interface AuthContextType {
|
||||||
|
user: User | null;
|
||||||
|
isAuthenticated: boolean;
|
||||||
|
isLoading: boolean;
|
||||||
|
logout: () => Promise<void>;
|
||||||
|
refreshUser: () => Promise<void>;
|
||||||
|
}
|
||||||
|
|
||||||
|
const AuthContext = createContext<AuthContextType | undefined>(undefined);
|
||||||
|
|
||||||
|
interface AuthProviderProps {
|
||||||
|
children: ReactNode;
|
||||||
|
initialUser: User | null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AuthProvider - Unified authentication context for the application
|
||||||
|
*
|
||||||
|
* Per RFC-001:
|
||||||
|
* - Only holds display information (user), never JWT or tokens
|
||||||
|
* - initialUser comes from server-side guard, avoiding client flicker
|
||||||
|
* - Provides logout and refresh capabilities
|
||||||
|
*/
|
||||||
|
export function AuthProvider({ children, initialUser }: AuthProviderProps) {
|
||||||
|
const [user, setUser] = useState<User | null>(initialUser);
|
||||||
|
const [isLoading, setIsLoading] = useState(false);
|
||||||
|
const router = useRouter();
|
||||||
|
const pathname = usePathname();
|
||||||
|
|
||||||
|
const isAuthenticated = user !== null;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetch current user from FastAPI
|
||||||
|
* Used when initialUser might be stale (e.g., after tab was inactive)
|
||||||
|
*/
|
||||||
|
const refreshUser = useCallback(async () => {
|
||||||
|
try {
|
||||||
|
setIsLoading(true);
|
||||||
|
const res = await fetch("/api/v1/auth/me", {
|
||||||
|
credentials: "include",
|
||||||
|
});
|
||||||
|
|
||||||
|
if (res.ok) {
|
||||||
|
const data = await res.json();
|
||||||
|
setUser(data);
|
||||||
|
} else if (res.status === 401) {
|
||||||
|
// Session expired or invalid
|
||||||
|
setUser(null);
|
||||||
|
// Redirect to login if on a protected route
|
||||||
|
if (pathname?.startsWith("/workspace")) {
|
||||||
|
router.push(buildLoginUrl(pathname));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Failed to refresh user:", err);
|
||||||
|
setUser(null);
|
||||||
|
} finally {
|
||||||
|
setIsLoading(false);
|
||||||
|
}
|
||||||
|
}, [pathname, router]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Logout - call FastAPI logout endpoint and clear local state
|
||||||
|
* Per RFC-001: Immediately clear local state, don't wait for server confirmation
|
||||||
|
*/
|
||||||
|
const logout = useCallback(async () => {
|
||||||
|
// Immediately clear local state to prevent UI flicker
|
||||||
|
setUser(null);
|
||||||
|
|
||||||
|
try {
|
||||||
|
await fetch("/api/v1/auth/logout", {
|
||||||
|
method: "POST",
|
||||||
|
credentials: "include",
|
||||||
|
});
|
||||||
|
} catch (err) {
|
||||||
|
console.error("Logout request failed:", err);
|
||||||
|
// Still redirect even if logout request fails
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect to home page
|
||||||
|
router.push("/");
|
||||||
|
}, [router]);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle visibility change - refresh user when tab becomes visible again.
|
||||||
|
* Throttled to at most once per 60 s to avoid spamming the backend on rapid tab switches.
|
||||||
|
*/
|
||||||
|
const lastCheckRef = React.useRef(0);
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
const handleVisibilityChange = () => {
|
||||||
|
if (document.visibilityState !== "visible" || user === null) return;
|
||||||
|
const now = Date.now();
|
||||||
|
if (now - lastCheckRef.current < 60_000) return;
|
||||||
|
lastCheckRef.current = now;
|
||||||
|
void refreshUser();
|
||||||
|
};
|
||||||
|
|
||||||
|
document.addEventListener("visibilitychange", handleVisibilityChange);
|
||||||
|
return () => {
|
||||||
|
document.removeEventListener("visibilitychange", handleVisibilityChange);
|
||||||
|
};
|
||||||
|
}, [user, refreshUser]);
|
||||||
|
|
||||||
|
const value: AuthContextType = {
|
||||||
|
user,
|
||||||
|
isAuthenticated,
|
||||||
|
isLoading,
|
||||||
|
logout,
|
||||||
|
refreshUser,
|
||||||
|
};
|
||||||
|
|
||||||
|
return <AuthContext.Provider value={value}>{children}</AuthContext.Provider>;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Hook to access authentication context
|
||||||
|
* Throws if used outside AuthProvider - this is intentional for proper usage
|
||||||
|
*/
|
||||||
|
export function useAuth(): AuthContextType {
|
||||||
|
const context = useContext(AuthContext);
|
||||||
|
if (context === undefined) {
|
||||||
|
throw new Error("useAuth must be used within an AuthProvider");
|
||||||
|
}
|
||||||
|
return context;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Hook to require authentication - redirects to login if not authenticated
|
||||||
|
* Useful for client-side checks in addition to server-side guards
|
||||||
|
*/
|
||||||
|
export function useRequireAuth(): AuthContextType {
|
||||||
|
const auth = useAuth();
|
||||||
|
const router = useRouter();
|
||||||
|
const pathname = usePathname();
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
// Only redirect if we're sure user is not authenticated (not just loading)
|
||||||
|
if (!auth.isLoading && !auth.isAuthenticated) {
|
||||||
|
router.push(buildLoginUrl(pathname || "/workspace"));
|
||||||
|
}
|
||||||
|
}, [auth.isAuthenticated, auth.isLoading, router, pathname]);
|
||||||
|
|
||||||
|
return auth;
|
||||||
|
}
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
const gatewayConfigSchema = z.object({
|
||||||
|
internalGatewayUrl: z.string().url(),
|
||||||
|
trustedOrigins: z.array(z.string()).min(1),
|
||||||
|
});
|
||||||
|
|
||||||
|
export type GatewayConfig = z.infer<typeof gatewayConfigSchema>;
|
||||||
|
|
||||||
|
let _cached: GatewayConfig | null = null;
|
||||||
|
|
||||||
|
export function getGatewayConfig(): GatewayConfig {
|
||||||
|
if (_cached) return _cached;
|
||||||
|
|
||||||
|
const isDev = process.env.NODE_ENV === "development";
|
||||||
|
|
||||||
|
const rawUrl = process.env.DEER_FLOW_INTERNAL_GATEWAY_BASE_URL?.trim();
|
||||||
|
const internalGatewayUrl =
|
||||||
|
rawUrl?.replace(/\/+$/, "") ??
|
||||||
|
(isDev ? "http://localhost:8001" : undefined);
|
||||||
|
|
||||||
|
const rawOrigins = process.env.DEER_FLOW_TRUSTED_ORIGINS?.trim();
|
||||||
|
const trustedOrigins = rawOrigins
|
||||||
|
? rawOrigins
|
||||||
|
.split(",")
|
||||||
|
.map((s) => s.trim())
|
||||||
|
.filter(Boolean)
|
||||||
|
: isDev
|
||||||
|
? ["http://localhost:3000"]
|
||||||
|
: undefined;
|
||||||
|
|
||||||
|
_cached = gatewayConfigSchema.parse({ internalGatewayUrl, trustedOrigins });
|
||||||
|
return _cached;
|
||||||
|
}
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
export interface ProxyPolicy {
|
||||||
|
/** Allowed upstream path prefixes */
|
||||||
|
readonly allowedPaths: readonly string[];
|
||||||
|
/** Request headers to strip before forwarding */
|
||||||
|
readonly strippedRequestHeaders: ReadonlySet<string>;
|
||||||
|
/** Response headers to strip before returning */
|
||||||
|
readonly strippedResponseHeaders: ReadonlySet<string>;
|
||||||
|
/** Credential mode: which cookie to forward */
|
||||||
|
readonly credential: { readonly type: "cookie"; readonly name: string };
|
||||||
|
/** Timeout in ms */
|
||||||
|
readonly timeoutMs: number;
|
||||||
|
/** CSRF: required for non-GET/HEAD */
|
||||||
|
readonly csrf: boolean;
|
||||||
|
}
|
||||||
|
|
||||||
|
export const LANGGRAPH_COMPAT_POLICY: ProxyPolicy = {
|
||||||
|
allowedPaths: [
|
||||||
|
"threads",
|
||||||
|
"runs",
|
||||||
|
"assistants",
|
||||||
|
"store",
|
||||||
|
"models",
|
||||||
|
"mcp",
|
||||||
|
"skills",
|
||||||
|
"memory",
|
||||||
|
],
|
||||||
|
strippedRequestHeaders: new Set([
|
||||||
|
"host",
|
||||||
|
"connection",
|
||||||
|
"keep-alive",
|
||||||
|
"transfer-encoding",
|
||||||
|
"te",
|
||||||
|
"trailer",
|
||||||
|
"upgrade",
|
||||||
|
"authorization",
|
||||||
|
"x-api-key",
|
||||||
|
"origin",
|
||||||
|
"referer",
|
||||||
|
"proxy-authorization",
|
||||||
|
"proxy-authenticate",
|
||||||
|
]),
|
||||||
|
strippedResponseHeaders: new Set([
|
||||||
|
"connection",
|
||||||
|
"keep-alive",
|
||||||
|
"transfer-encoding",
|
||||||
|
"te",
|
||||||
|
"trailer",
|
||||||
|
"upgrade",
|
||||||
|
"content-length",
|
||||||
|
"set-cookie",
|
||||||
|
]),
|
||||||
|
credential: { type: "cookie", name: "access_token" },
|
||||||
|
timeoutMs: 120_000,
|
||||||
|
csrf: true,
|
||||||
|
};
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
import { cookies } from "next/headers";
|
||||||
|
|
||||||
|
import { getGatewayConfig } from "./gateway-config";
|
||||||
|
import { type AuthResult, userSchema } from "./types";
|
||||||
|
|
||||||
|
const SSR_AUTH_TIMEOUT_MS = 5_000;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetch the authenticated user from the gateway using the request's cookies.
|
||||||
|
* Returns a tagged AuthResult — callers use exhaustive switch, no try/catch.
|
||||||
|
*/
|
||||||
|
export async function getServerSideUser(): Promise<AuthResult> {
|
||||||
|
const cookieStore = await cookies();
|
||||||
|
const sessionCookie = cookieStore.get("access_token");
|
||||||
|
|
||||||
|
let internalGatewayUrl: string;
|
||||||
|
try {
|
||||||
|
internalGatewayUrl = getGatewayConfig().internalGatewayUrl;
|
||||||
|
} catch (err) {
|
||||||
|
return { tag: "config_error", message: String(err) };
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!sessionCookie) return { tag: "unauthenticated" };
|
||||||
|
|
||||||
|
const controller = new AbortController();
|
||||||
|
const timeout = setTimeout(() => controller.abort(), SSR_AUTH_TIMEOUT_MS);
|
||||||
|
|
||||||
|
try {
|
||||||
|
const res = await fetch(`${internalGatewayUrl}/api/v1/auth/me`, {
|
||||||
|
headers: { Cookie: `access_token=${sessionCookie.value}` },
|
||||||
|
cache: "no-store",
|
||||||
|
signal: controller.signal,
|
||||||
|
});
|
||||||
|
clearTimeout(timeout); // Clear immediately — covers all response branches
|
||||||
|
|
||||||
|
if (res.ok) {
|
||||||
|
const parsed = userSchema.safeParse(await res.json());
|
||||||
|
if (!parsed.success) {
|
||||||
|
console.error("[SSR auth] Malformed /auth/me response:", parsed.error);
|
||||||
|
return { tag: "gateway_unavailable" };
|
||||||
|
}
|
||||||
|
if (parsed.data.needs_setup) {
|
||||||
|
return { tag: "needs_setup", user: parsed.data };
|
||||||
|
}
|
||||||
|
return { tag: "authenticated", user: parsed.data };
|
||||||
|
}
|
||||||
|
if (res.status === 401 || res.status === 403) {
|
||||||
|
return { tag: "unauthenticated" };
|
||||||
|
}
|
||||||
|
console.error(`[SSR auth] /api/v1/auth/me responded ${res.status}`);
|
||||||
|
return { tag: "gateway_unavailable" };
|
||||||
|
} catch (err) {
|
||||||
|
clearTimeout(timeout);
|
||||||
|
console.error("[SSR auth] Failed to reach gateway:", err);
|
||||||
|
return { tag: "gateway_unavailable" };
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,72 @@
|
|||||||
|
import { z } from "zod";
|
||||||
|
|
||||||
|
// ── User schema (single source of truth) ──────────────────────────
|
||||||
|
|
||||||
|
export const userSchema = z.object({
|
||||||
|
id: z.string(),
|
||||||
|
email: z.string().email(),
|
||||||
|
system_role: z.enum(["admin", "user"]),
|
||||||
|
needs_setup: z.boolean().optional().default(false),
|
||||||
|
});
|
||||||
|
|
||||||
|
export type User = z.infer<typeof userSchema>;
|
||||||
|
|
||||||
|
// ── SSR auth result (tagged union) ────────────────────────────────
|
||||||
|
|
||||||
|
export type AuthResult =
|
||||||
|
| { tag: "authenticated"; user: User }
|
||||||
|
| { tag: "needs_setup"; user: User }
|
||||||
|
| { tag: "unauthenticated" }
|
||||||
|
| { tag: "gateway_unavailable" }
|
||||||
|
| { tag: "config_error"; message: string };
|
||||||
|
|
||||||
|
export function assertNever(x: never): never {
|
||||||
|
throw new Error(`Unexpected auth result: ${JSON.stringify(x)}`);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function buildLoginUrl(returnPath: string): string {
|
||||||
|
return `/login?next=${encodeURIComponent(returnPath)}`;
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Backend error response parsing ────────────────────────────────
|
||||||
|
|
||||||
|
const AUTH_ERROR_CODES = [
|
||||||
|
"invalid_credentials",
|
||||||
|
"token_expired",
|
||||||
|
"token_invalid",
|
||||||
|
"user_not_found",
|
||||||
|
"email_already_exists",
|
||||||
|
"provider_not_found",
|
||||||
|
"not_authenticated",
|
||||||
|
] as const;
|
||||||
|
|
||||||
|
export type AuthErrorCode = (typeof AUTH_ERROR_CODES)[number];
|
||||||
|
|
||||||
|
export interface AuthErrorResponse {
|
||||||
|
code: AuthErrorCode;
|
||||||
|
message: string;
|
||||||
|
}
|
||||||
|
|
||||||
|
const authErrorSchema = z.object({
|
||||||
|
code: z.enum(AUTH_ERROR_CODES),
|
||||||
|
message: z.string(),
|
||||||
|
});
|
||||||
|
|
||||||
|
export function parseAuthError(data: unknown): AuthErrorResponse {
|
||||||
|
// Try top-level {code, message} first
|
||||||
|
const parsed = authErrorSchema.safeParse(data);
|
||||||
|
if (parsed.success) return parsed.data;
|
||||||
|
|
||||||
|
// Unwrap FastAPI's {detail: {code, message}} envelope
|
||||||
|
if (typeof data === "object" && data !== null && "detail" in data) {
|
||||||
|
const detail = (data as Record<string, unknown>).detail;
|
||||||
|
const nested = authErrorSchema.safeParse(detail);
|
||||||
|
if (nested.success) return nested.data;
|
||||||
|
// Legacy string-detail responses
|
||||||
|
if (typeof detail === "string") {
|
||||||
|
return { code: "invalid_credentials", message: detail };
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return { code: "invalid_credentials", message: "Authentication failed" };
|
||||||
|
}
|
||||||
@@ -7,12 +7,6 @@ export const env = createEnv({
|
|||||||
* isn't built with invalid env vars.
|
* isn't built with invalid env vars.
|
||||||
*/
|
*/
|
||||||
server: {
|
server: {
|
||||||
BETTER_AUTH_SECRET:
|
|
||||||
process.env.NODE_ENV === "production"
|
|
||||||
? z.string()
|
|
||||||
: z.string().optional(),
|
|
||||||
BETTER_AUTH_GITHUB_CLIENT_ID: z.string().optional(),
|
|
||||||
BETTER_AUTH_GITHUB_CLIENT_SECRET: z.string().optional(),
|
|
||||||
GITHUB_OAUTH_TOKEN: z.string().optional(),
|
GITHUB_OAUTH_TOKEN: z.string().optional(),
|
||||||
NODE_ENV: z
|
NODE_ENV: z
|
||||||
.enum(["development", "test", "production"])
|
.enum(["development", "test", "production"])
|
||||||
@@ -35,10 +29,6 @@ export const env = createEnv({
|
|||||||
* middlewares) or client-side so we need to destruct manually.
|
* middlewares) or client-side so we need to destruct manually.
|
||||||
*/
|
*/
|
||||||
runtimeEnv: {
|
runtimeEnv: {
|
||||||
BETTER_AUTH_SECRET: process.env.BETTER_AUTH_SECRET,
|
|
||||||
BETTER_AUTH_GITHUB_CLIENT_ID: process.env.BETTER_AUTH_GITHUB_CLIENT_ID,
|
|
||||||
BETTER_AUTH_GITHUB_CLIENT_SECRET:
|
|
||||||
process.env.BETTER_AUTH_GITHUB_CLIENT_SECRET,
|
|
||||||
NODE_ENV: process.env.NODE_ENV,
|
NODE_ENV: process.env.NODE_ENV,
|
||||||
|
|
||||||
NEXT_PUBLIC_BACKEND_BASE_URL: process.env.NEXT_PUBLIC_BACKEND_BASE_URL,
|
NEXT_PUBLIC_BACKEND_BASE_URL: process.env.NEXT_PUBLIC_BACKEND_BASE_URL,
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
import { createAuthClient } from "better-auth/react";
|
|
||||||
|
|
||||||
export const authClient = createAuthClient();
|
|
||||||
|
|
||||||
export type Session = typeof authClient.$Infer.Session;
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
import { betterAuth } from "better-auth";
|
|
||||||
|
|
||||||
export const auth = betterAuth({
|
|
||||||
emailAndPassword: {
|
|
||||||
enabled: true,
|
|
||||||
},
|
|
||||||
});
|
|
||||||
|
|
||||||
export type Session = typeof auth.$Infer.Session;
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
export { auth } from "./config";
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
import { headers } from "next/headers";
|
|
||||||
import { cache } from "react";
|
|
||||||
|
|
||||||
import { auth } from ".";
|
|
||||||
|
|
||||||
export const getSession = cache(async () =>
|
|
||||||
auth.api.getSession({ headers: await headers() }),
|
|
||||||
);
|
|
||||||
Reference in New Issue
Block a user