refactor(gateway): remove old auth system and middleware
Remove deprecated authentication and authorization modules: - app/gateway/auth/ - auth providers, JWT, password handling, repositories - app/gateway/auth_middleware.py - authentication middleware - app/gateway/authz.py - authorization module - app/gateway/csrf_middleware.py - CSRF protection middleware - app/gateway/deps.py - old dependency injection - app/gateway/langgraph_auth.py - LangGraph authentication - app/gateway/routers/auth.py - auth API endpoints - app/gateway/routers/assistants_compat.py - assistants compatibility layer These are replaced by the new auth system in packages/storage/. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -1,42 +0,0 @@
|
|||||||
"""Authentication module for DeerFlow.
|
|
||||||
|
|
||||||
This module provides:
|
|
||||||
- JWT-based authentication
|
|
||||||
- Provider Factory pattern for extensible auth methods
|
|
||||||
- UserRepository interface for storage backends (SQLite)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
|
||||||
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
|
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
|
||||||
from app.gateway.auth.models import User, UserResponse
|
|
||||||
from app.gateway.auth.password import hash_password, verify_password
|
|
||||||
from app.gateway.auth.providers import AuthProvider
|
|
||||||
from app.gateway.auth.repositories.base import UserRepository
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
# Config
|
|
||||||
"AuthConfig",
|
|
||||||
"get_auth_config",
|
|
||||||
"set_auth_config",
|
|
||||||
# Errors
|
|
||||||
"AuthErrorCode",
|
|
||||||
"AuthErrorResponse",
|
|
||||||
"TokenError",
|
|
||||||
# JWT
|
|
||||||
"TokenPayload",
|
|
||||||
"create_access_token",
|
|
||||||
"decode_token",
|
|
||||||
# Password
|
|
||||||
"hash_password",
|
|
||||||
"verify_password",
|
|
||||||
# Models
|
|
||||||
"User",
|
|
||||||
"UserResponse",
|
|
||||||
# Providers
|
|
||||||
"AuthProvider",
|
|
||||||
"LocalAuthProvider",
|
|
||||||
# Repository
|
|
||||||
"UserRepository",
|
|
||||||
]
|
|
||||||
@@ -1,57 +0,0 @@
|
|||||||
"""Authentication configuration for DeerFlow."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import secrets
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
load_dotenv()
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthConfig(BaseModel):
|
|
||||||
"""JWT and auth-related configuration. Parsed once at startup.
|
|
||||||
|
|
||||||
Note: the ``users`` table now lives in the shared persistence
|
|
||||||
database managed by ``deerflow.persistence.engine``. The old
|
|
||||||
``users_db_path`` config key has been removed — user storage is
|
|
||||||
configured through ``config.database`` like every other table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
jwt_secret: str = Field(
|
|
||||||
...,
|
|
||||||
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
|
|
||||||
)
|
|
||||||
token_expiry_days: int = Field(default=7, ge=1, le=30)
|
|
||||||
oauth_github_client_id: str | None = Field(default=None)
|
|
||||||
oauth_github_client_secret: str | None = Field(default=None)
|
|
||||||
|
|
||||||
|
|
||||||
_auth_config: AuthConfig | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_auth_config() -> AuthConfig:
|
|
||||||
"""Get the global AuthConfig instance. Parses from env on first call."""
|
|
||||||
global _auth_config
|
|
||||||
if _auth_config is None:
|
|
||||||
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
|
||||||
if not jwt_secret:
|
|
||||||
jwt_secret = secrets.token_urlsafe(32)
|
|
||||||
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
|
||||||
logger.warning(
|
|
||||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
|
|
||||||
"Sessions will be invalidated on restart. "
|
|
||||||
"For production, add AUTH_JWT_SECRET to your .env file: "
|
|
||||||
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
|
||||||
)
|
|
||||||
_auth_config = AuthConfig(jwt_secret=jwt_secret)
|
|
||||||
return _auth_config
|
|
||||||
|
|
||||||
|
|
||||||
def set_auth_config(config: AuthConfig) -> None:
|
|
||||||
"""Set the global AuthConfig instance (for testing)."""
|
|
||||||
global _auth_config
|
|
||||||
_auth_config = config
|
|
||||||
@@ -1,48 +0,0 @@
|
|||||||
"""Write initial admin credentials to a restricted file instead of logs.
|
|
||||||
|
|
||||||
Logging secrets to stdout/stderr is a well-known CodeQL finding
|
|
||||||
(py/clear-text-logging-sensitive-data) — in production those logs
|
|
||||||
get collected into ELK/Splunk/etc and become a secret sprawl
|
|
||||||
source. This helper writes the credential to a 0600 file that only
|
|
||||||
the process user can read, and returns the path so the caller can
|
|
||||||
log **the path** (not the password) for the operator to pick up.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from deerflow.config.paths import get_paths
|
|
||||||
|
|
||||||
_CREDENTIAL_FILENAME = "admin_initial_credentials.txt"
|
|
||||||
|
|
||||||
|
|
||||||
def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
|
|
||||||
"""Write the admin email + password to ``{base_dir}/admin_initial_credentials.txt``.
|
|
||||||
|
|
||||||
The file is created **atomically** with mode 0600 via ``os.open``
|
|
||||||
so the password is never world-readable, even for the single syscall
|
|
||||||
window between ``write_text`` and ``chmod``.
|
|
||||||
|
|
||||||
``label`` distinguishes "initial" (fresh creation) from "reset"
|
|
||||||
(password reset) in the file header so an operator picking up the
|
|
||||||
file after a restart can tell which event produced it.
|
|
||||||
|
|
||||||
Returns the absolute :class:`Path` to the file.
|
|
||||||
"""
|
|
||||||
target = get_paths().base_dir / _CREDENTIAL_FILENAME
|
|
||||||
target.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
|
|
||||||
content = (
|
|
||||||
f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Atomic 0600 create-or-truncate. O_TRUNC (not O_EXCL) so the
|
|
||||||
# reset-password path can rewrite an existing file without a
|
|
||||||
# separate unlink-then-create dance.
|
|
||||||
fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
|
||||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
|
||||||
fh.write(content)
|
|
||||||
|
|
||||||
return target.resolve()
|
|
||||||
@@ -1,45 +0,0 @@
|
|||||||
"""Typed error definitions for auth module.
|
|
||||||
|
|
||||||
AuthErrorCode: exhaustive enum of all auth failure conditions.
|
|
||||||
TokenError: exhaustive enum of JWT decode failures.
|
|
||||||
AuthErrorResponse: structured error payload for HTTP responses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from enum import StrEnum
|
|
||||||
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
|
|
||||||
class AuthErrorCode(StrEnum):
|
|
||||||
"""Exhaustive list of auth error conditions."""
|
|
||||||
|
|
||||||
INVALID_CREDENTIALS = "invalid_credentials"
|
|
||||||
TOKEN_EXPIRED = "token_expired"
|
|
||||||
TOKEN_INVALID = "token_invalid"
|
|
||||||
USER_NOT_FOUND = "user_not_found"
|
|
||||||
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
|
||||||
PROVIDER_NOT_FOUND = "provider_not_found"
|
|
||||||
NOT_AUTHENTICATED = "not_authenticated"
|
|
||||||
SYSTEM_ALREADY_INITIALIZED = "system_already_initialized"
|
|
||||||
|
|
||||||
|
|
||||||
class TokenError(StrEnum):
|
|
||||||
"""Exhaustive list of JWT decode failure reasons."""
|
|
||||||
|
|
||||||
EXPIRED = "expired"
|
|
||||||
INVALID_SIGNATURE = "invalid_signature"
|
|
||||||
MALFORMED = "malformed"
|
|
||||||
|
|
||||||
|
|
||||||
class AuthErrorResponse(BaseModel):
|
|
||||||
"""Structured error response — replaces bare `detail` strings."""
|
|
||||||
|
|
||||||
code: AuthErrorCode
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
def token_error_to_code(err: TokenError) -> AuthErrorCode:
|
|
||||||
"""Map TokenError to AuthErrorCode — single source of truth."""
|
|
||||||
if err == TokenError.EXPIRED:
|
|
||||||
return AuthErrorCode.TOKEN_EXPIRED
|
|
||||||
return AuthErrorCode.TOKEN_INVALID
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
"""JWT token creation and verification."""
|
|
||||||
|
|
||||||
from datetime import UTC, datetime, timedelta
|
|
||||||
|
|
||||||
import jwt
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.gateway.auth.config import get_auth_config
|
|
||||||
from app.gateway.auth.errors import TokenError
|
|
||||||
|
|
||||||
|
|
||||||
class TokenPayload(BaseModel):
|
|
||||||
"""JWT token payload."""
|
|
||||||
|
|
||||||
sub: str # user_id
|
|
||||||
exp: datetime
|
|
||||||
iat: datetime | None = None
|
|
||||||
ver: int = 0 # token_version — must match User.token_version
|
|
||||||
|
|
||||||
|
|
||||||
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
|
|
||||||
"""Create a JWT access token.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: The user's UUID as string
|
|
||||||
expires_delta: Optional custom expiry, defaults to 7 days
|
|
||||||
token_version: User's current token_version for invalidation
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Encoded JWT string
|
|
||||||
"""
|
|
||||||
config = get_auth_config()
|
|
||||||
expiry = expires_delta or timedelta(days=config.token_expiry_days)
|
|
||||||
|
|
||||||
now = datetime.now(UTC)
|
|
||||||
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
|
|
||||||
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
|
|
||||||
|
|
||||||
|
|
||||||
def decode_token(token: str) -> TokenPayload | TokenError:
|
|
||||||
"""Decode and validate a JWT token.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
TokenPayload if valid, or a specific TokenError variant.
|
|
||||||
"""
|
|
||||||
config = get_auth_config()
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
|
|
||||||
return TokenPayload(**payload)
|
|
||||||
except jwt.ExpiredSignatureError:
|
|
||||||
return TokenError.EXPIRED
|
|
||||||
except jwt.InvalidSignatureError:
|
|
||||||
return TokenError.INVALID_SIGNATURE
|
|
||||||
except jwt.PyJWTError:
|
|
||||||
return TokenError.MALFORMED
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
"""Local email/password authentication provider."""
|
|
||||||
|
|
||||||
from app.gateway.auth.models import User
|
|
||||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
|
||||||
from app.gateway.auth.providers import AuthProvider
|
|
||||||
from app.gateway.auth.repositories.base import UserRepository
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAuthProvider(AuthProvider):
|
|
||||||
"""Email/password authentication provider using local database."""
|
|
||||||
|
|
||||||
def __init__(self, repository: UserRepository):
|
|
||||||
"""Initialize with a UserRepository.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
repository: UserRepository implementation (SQLite)
|
|
||||||
"""
|
|
||||||
self._repo = repository
|
|
||||||
|
|
||||||
async def authenticate(self, credentials: dict) -> User | None:
|
|
||||||
"""Authenticate with email and password.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
credentials: dict with 'email' and 'password' keys
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User if authentication succeeds, None otherwise
|
|
||||||
"""
|
|
||||||
email = credentials.get("email")
|
|
||||||
password = credentials.get("password")
|
|
||||||
|
|
||||||
if not email or not password:
|
|
||||||
return None
|
|
||||||
|
|
||||||
user = await self._repo.get_user_by_email(email)
|
|
||||||
if user is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if user.password_hash is None:
|
|
||||||
# OAuth user without local password
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not await verify_password_async(password, user.password_hash):
|
|
||||||
return None
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def get_user(self, user_id: str) -> User | None:
|
|
||||||
"""Get user by ID."""
|
|
||||||
return await self._repo.get_user_by_id(user_id)
|
|
||||||
|
|
||||||
async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
|
|
||||||
"""Create a new local user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
email: User email address
|
|
||||||
password: Plain text password (will be hashed)
|
|
||||||
system_role: Role to assign ("admin" or "user")
|
|
||||||
needs_setup: If True, user must complete setup on first login
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created User instance
|
|
||||||
"""
|
|
||||||
password_hash = await hash_password_async(password) if password else None
|
|
||||||
user = User(
|
|
||||||
email=email,
|
|
||||||
password_hash=password_hash,
|
|
||||||
system_role=system_role,
|
|
||||||
needs_setup=needs_setup,
|
|
||||||
)
|
|
||||||
return await self._repo.create_user(user)
|
|
||||||
|
|
||||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
|
||||||
"""Get user by OAuth provider and ID."""
|
|
||||||
return await self._repo.get_user_by_oauth(provider, oauth_id)
|
|
||||||
|
|
||||||
async def count_users(self) -> int:
|
|
||||||
"""Return total number of registered users."""
|
|
||||||
return await self._repo.count_users()
|
|
||||||
|
|
||||||
async def count_admin_users(self) -> int:
|
|
||||||
"""Return number of admin users."""
|
|
||||||
return await self._repo.count_admin_users()
|
|
||||||
|
|
||||||
async def update_user(self, user: User) -> User:
|
|
||||||
"""Update an existing user."""
|
|
||||||
return await self._repo.update_user(user)
|
|
||||||
|
|
||||||
async def get_user_by_email(self, email: str) -> User | None:
|
|
||||||
"""Get user by email."""
|
|
||||||
return await self._repo.get_user_by_email(email)
|
|
||||||
@@ -1,41 +0,0 @@
|
|||||||
"""User Pydantic models for authentication."""
|
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
from typing import Literal
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
|
||||||
|
|
||||||
|
|
||||||
def _utc_now() -> datetime:
|
|
||||||
"""Return current UTC time (timezone-aware)."""
|
|
||||||
return datetime.now(UTC)
|
|
||||||
|
|
||||||
|
|
||||||
class User(BaseModel):
|
|
||||||
"""Internal user representation."""
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
id: UUID = Field(default_factory=uuid4, description="Primary key")
|
|
||||||
email: EmailStr = Field(..., description="Unique email address")
|
|
||||||
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
|
|
||||||
system_role: Literal["admin", "user"] = Field(default="user")
|
|
||||||
created_at: datetime = Field(default_factory=_utc_now)
|
|
||||||
|
|
||||||
# OAuth linkage (optional)
|
|
||||||
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
|
|
||||||
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
|
|
||||||
|
|
||||||
# Auth lifecycle
|
|
||||||
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
|
|
||||||
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
|
|
||||||
|
|
||||||
|
|
||||||
class UserResponse(BaseModel):
|
|
||||||
"""Response model for user info endpoint."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
email: str
|
|
||||||
system_role: Literal["admin", "user"]
|
|
||||||
needs_setup: bool = False
|
|
||||||
@@ -1,33 +0,0 @@
|
|||||||
"""Password hashing utilities using bcrypt directly."""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
|
|
||||||
|
|
||||||
def hash_password(password: str) -> str:
|
|
||||||
"""Hash a password using bcrypt."""
|
|
||||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
|
||||||
"""Verify a password against its hash."""
|
|
||||||
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
async def hash_password_async(password: str) -> str:
|
|
||||||
"""Hash a password using bcrypt (non-blocking).
|
|
||||||
|
|
||||||
Wraps the blocking bcrypt operation in a thread pool to avoid
|
|
||||||
blocking the event loop during password hashing.
|
|
||||||
"""
|
|
||||||
return await asyncio.to_thread(hash_password, password)
|
|
||||||
|
|
||||||
|
|
||||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
|
||||||
"""Verify a password against its hash (non-blocking).
|
|
||||||
|
|
||||||
Wraps the blocking bcrypt operation in a thread pool to avoid
|
|
||||||
blocking the event loop during password verification.
|
|
||||||
"""
|
|
||||||
return await asyncio.to_thread(verify_password, plain_password, hashed_password)
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
"""Auth provider abstraction."""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
|
|
||||||
class AuthProvider(ABC):
|
|
||||||
"""Abstract base class for authentication providers."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def authenticate(self, credentials: dict) -> "User | None":
|
|
||||||
"""Authenticate user with given credentials.
|
|
||||||
|
|
||||||
Returns User if authentication succeeds, None otherwise.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def get_user(self, user_id: str) -> "User | None":
|
|
||||||
"""Retrieve user by ID."""
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
# Import User at runtime to avoid circular imports
|
|
||||||
from app.gateway.auth.models import User # noqa: E402
|
|
||||||
@@ -1,102 +0,0 @@
|
|||||||
"""User repository interface for abstracting database operations."""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from app.gateway.auth.models import User
|
|
||||||
|
|
||||||
|
|
||||||
class UserNotFoundError(LookupError):
|
|
||||||
"""Raised when a user repository operation targets a non-existent row.
|
|
||||||
|
|
||||||
Subclass of :class:`LookupError` so callers that already catch
|
|
||||||
``LookupError`` for "missing entity" can keep working unchanged,
|
|
||||||
while specific call sites can pin to this class to distinguish
|
|
||||||
"concurrent delete during update" from other lookups.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class UserRepository(ABC):
|
|
||||||
"""Abstract interface for user data storage.
|
|
||||||
|
|
||||||
Implement this interface to support different storage backends
|
|
||||||
(SQLite)
|
|
||||||
"""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def create_user(self, user: User) -> User:
|
|
||||||
"""Create a new user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user: User object to create
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Created User with ID assigned
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If email already exists
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
|
||||||
"""Get user by ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: User UUID as string
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User if found, None otherwise
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def get_user_by_email(self, email: str) -> User | None:
|
|
||||||
"""Get user by email.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
email: User email address
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User if found, None otherwise
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def update_user(self, user: User) -> User:
|
|
||||||
"""Update an existing user.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user: User object with updated fields
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Updated User
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
UserNotFoundError: If no row exists for ``user.id``. This is
|
|
||||||
a hard failure (not a no-op) so callers cannot mistake a
|
|
||||||
concurrent-delete race for a successful update.
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def count_users(self) -> int:
|
|
||||||
"""Return total number of registered users."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def count_admin_users(self) -> int:
|
|
||||||
"""Return number of users with system_role == 'admin'."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
|
||||||
"""Get user by OAuth provider and ID.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
provider: OAuth provider name (e.g. 'github', 'google')
|
|
||||||
oauth_id: User ID from the OAuth provider
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
User if found, None otherwise
|
|
||||||
"""
|
|
||||||
...
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
"""SQLAlchemy-backed UserRepository implementation.
|
|
||||||
|
|
||||||
Uses the shared async session factory from
|
|
||||||
``deerflow.persistence.engine`` — the ``users`` table lives in the
|
|
||||||
same database as ``threads_meta``, ``runs``, ``run_events``, and
|
|
||||||
``feedback``.
|
|
||||||
|
|
||||||
Constructor takes the session factory directly (same pattern as the
|
|
||||||
other four repositories in ``deerflow.persistence.*``). Callers
|
|
||||||
construct this after ``init_engine_from_config()`` has run.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import UTC
|
|
||||||
from uuid import UUID
|
|
||||||
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.exc import IntegrityError
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
||||||
|
|
||||||
from app.gateway.auth.models import User
|
|
||||||
from app.gateway.auth.repositories.base import UserNotFoundError, UserRepository
|
|
||||||
from deerflow.persistence.user.model import UserRow
|
|
||||||
|
|
||||||
|
|
||||||
class SQLiteUserRepository(UserRepository):
|
|
||||||
"""Async user repository backed by the shared SQLAlchemy engine."""
|
|
||||||
|
|
||||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
|
||||||
self._sf = session_factory
|
|
||||||
|
|
||||||
# ── Converters ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _row_to_user(row: UserRow) -> User:
|
|
||||||
return User(
|
|
||||||
id=UUID(row.id),
|
|
||||||
email=row.email,
|
|
||||||
password_hash=row.password_hash,
|
|
||||||
system_role=row.system_role, # type: ignore[arg-type]
|
|
||||||
# SQLite loses tzinfo on read; reattach UTC so downstream
|
|
||||||
# code can compare timestamps reliably.
|
|
||||||
created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
|
|
||||||
oauth_provider=row.oauth_provider,
|
|
||||||
oauth_id=row.oauth_id,
|
|
||||||
needs_setup=row.needs_setup,
|
|
||||||
token_version=row.token_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _user_to_row(user: User) -> UserRow:
|
|
||||||
return UserRow(
|
|
||||||
id=str(user.id),
|
|
||||||
email=user.email,
|
|
||||||
password_hash=user.password_hash,
|
|
||||||
system_role=user.system_role,
|
|
||||||
created_at=user.created_at,
|
|
||||||
oauth_provider=user.oauth_provider,
|
|
||||||
oauth_id=user.oauth_id,
|
|
||||||
needs_setup=user.needs_setup,
|
|
||||||
token_version=user.token_version,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── CRUD ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def create_user(self, user: User) -> User:
|
|
||||||
"""Insert a new user. Raises ``ValueError`` on duplicate email."""
|
|
||||||
row = self._user_to_row(user)
|
|
||||||
async with self._sf() as session:
|
|
||||||
session.add(row)
|
|
||||||
try:
|
|
||||||
await session.commit()
|
|
||||||
except IntegrityError as exc:
|
|
||||||
await session.rollback()
|
|
||||||
raise ValueError(f"Email already registered: {user.email}") from exc
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
|
||||||
async with self._sf() as session:
|
|
||||||
row = await session.get(UserRow, user_id)
|
|
||||||
return self._row_to_user(row) if row is not None else None
|
|
||||||
|
|
||||||
async def get_user_by_email(self, email: str) -> User | None:
|
|
||||||
stmt = select(UserRow).where(UserRow.email == email)
|
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
return self._row_to_user(row) if row is not None else None
|
|
||||||
|
|
||||||
async def update_user(self, user: User) -> User:
|
|
||||||
async with self._sf() as session:
|
|
||||||
row = await session.get(UserRow, str(user.id))
|
|
||||||
if row is None:
|
|
||||||
# Hard fail on concurrent delete: callers (reset_admin,
|
|
||||||
# password change handlers, _ensure_admin_user) all
|
|
||||||
# fetched the user just before this call, so a missing
|
|
||||||
# row here means the row vanished underneath us. Silent
|
|
||||||
# success would let the caller log "password reset" for
|
|
||||||
# a row that no longer exists.
|
|
||||||
raise UserNotFoundError(f"User {user.id} no longer exists")
|
|
||||||
row.email = user.email
|
|
||||||
row.password_hash = user.password_hash
|
|
||||||
row.system_role = user.system_role
|
|
||||||
row.oauth_provider = user.oauth_provider
|
|
||||||
row.oauth_id = user.oauth_id
|
|
||||||
row.needs_setup = user.needs_setup
|
|
||||||
row.token_version = user.token_version
|
|
||||||
await session.commit()
|
|
||||||
return user
|
|
||||||
|
|
||||||
async def count_users(self) -> int:
|
|
||||||
stmt = select(func.count()).select_from(UserRow)
|
|
||||||
async with self._sf() as session:
|
|
||||||
return await session.scalar(stmt) or 0
|
|
||||||
|
|
||||||
async def count_admin_users(self) -> int:
|
|
||||||
stmt = select(func.count()).select_from(UserRow).where(UserRow.system_role == "admin")
|
|
||||||
async with self._sf() as session:
|
|
||||||
return await session.scalar(stmt) or 0
|
|
||||||
|
|
||||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
|
||||||
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
|
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
return self._row_to_user(row) if row is not None else None
|
|
||||||
@@ -1,91 +0,0 @@
|
|||||||
"""CLI tool to reset an admin password.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python -m app.gateway.auth.reset_admin
|
|
||||||
python -m app.gateway.auth.reset_admin --email admin@example.com
|
|
||||||
|
|
||||||
Writes the new password to ``.deer-flow/admin_initial_credentials.txt``
|
|
||||||
(mode 0600) instead of printing it, so CI / log aggregators never see
|
|
||||||
the cleartext secret.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import secrets
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from app.gateway.auth.credential_file import write_initial_credentials
|
|
||||||
from app.gateway.auth.password import hash_password
|
|
||||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
|
||||||
from deerflow.persistence.user.model import UserRow
|
|
||||||
|
|
||||||
|
|
||||||
async def _run(email: str | None) -> int:
|
|
||||||
from deerflow.config import get_app_config
|
|
||||||
from deerflow.persistence.engine import (
|
|
||||||
close_engine,
|
|
||||||
get_session_factory,
|
|
||||||
init_engine_from_config,
|
|
||||||
)
|
|
||||||
|
|
||||||
config = get_app_config()
|
|
||||||
await init_engine_from_config(config.database)
|
|
||||||
try:
|
|
||||||
sf = get_session_factory()
|
|
||||||
if sf is None:
|
|
||||||
print("Error: persistence engine not available (check config.database).", file=sys.stderr)
|
|
||||||
return 1
|
|
||||||
|
|
||||||
repo = SQLiteUserRepository(sf)
|
|
||||||
|
|
||||||
if email:
|
|
||||||
user = await repo.get_user_by_email(email)
|
|
||||||
else:
|
|
||||||
# Find first admin via direct SELECT — repository does not
|
|
||||||
# expose a "first admin" helper and we do not want to add
|
|
||||||
# one just for this CLI.
|
|
||||||
async with sf() as session:
|
|
||||||
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
|
|
||||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
user = None
|
|
||||||
else:
|
|
||||||
user = await repo.get_user_by_id(row.id)
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
if email:
|
|
||||||
print(f"Error: user '{email}' not found.", file=sys.stderr)
|
|
||||||
else:
|
|
||||||
print("Error: no admin user found.", file=sys.stderr)
|
|
||||||
return 1
|
|
||||||
|
|
||||||
new_password = secrets.token_urlsafe(16)
|
|
||||||
user.password_hash = hash_password(new_password)
|
|
||||||
user.token_version += 1
|
|
||||||
user.needs_setup = True
|
|
||||||
await repo.update_user(user)
|
|
||||||
|
|
||||||
cred_path = write_initial_credentials(user.email, new_password, label="reset")
|
|
||||||
print(f"Password reset for: {user.email}")
|
|
||||||
print(f"Credentials written to: {cred_path} (mode 0600)")
|
|
||||||
print("Next login will require setup (new email + password).")
|
|
||||||
return 0
|
|
||||||
finally:
|
|
||||||
await close_engine()
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
parser = argparse.ArgumentParser(description="Reset admin password")
|
|
||||||
parser.add_argument("--email", help="Admin email (default: first admin found)")
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
exit_code = asyncio.run(_run(args.email))
|
|
||||||
sys.exit(exit_code)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,118 +0,0 @@
|
|||||||
"""Global authentication middleware — fail-closed safety net.
|
|
||||||
|
|
||||||
Rejects unauthenticated requests to non-public paths with 401. When a
|
|
||||||
request passes the cookie check, resolves the JWT payload to a real
|
|
||||||
``User`` object and stamps it into both ``request.state.user`` and the
|
|
||||||
``deerflow.runtime.user_context`` contextvar so that repository-layer
|
|
||||||
owner filtering works automatically via the sentinel pattern.
|
|
||||||
|
|
||||||
Fine-grained permission checks remain in authz.py decorators.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Request, Response
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import JSONResponse
|
|
||||||
from starlette.types import ASGIApp
|
|
||||||
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
|
||||||
from app.gateway.authz import _ALL_PERMISSIONS, AuthContext
|
|
||||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
|
||||||
|
|
||||||
# Paths that never require authentication.
|
|
||||||
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
|
|
||||||
"/health",
|
|
||||||
"/docs",
|
|
||||||
"/redoc",
|
|
||||||
"/openapi.json",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exact auth paths that are public (login/register/status check).
|
|
||||||
# /api/v1/auth/me, /api/v1/auth/change-password etc. are NOT public.
|
|
||||||
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
|
||||||
{
|
|
||||||
"/api/v1/auth/login/local",
|
|
||||||
"/api/v1/auth/register",
|
|
||||||
"/api/v1/auth/logout",
|
|
||||||
"/api/v1/auth/setup-status",
|
|
||||||
"/api/v1/auth/initialize",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_public(path: str) -> bool:
|
|
||||||
stripped = path.rstrip("/")
|
|
||||||
if stripped in _PUBLIC_EXACT_PATHS:
|
|
||||||
return True
|
|
||||||
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
|
|
||||||
|
|
||||||
|
|
||||||
class AuthMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""Strict auth gate: reject requests without a valid session.
|
|
||||||
|
|
||||||
Two-stage check for non-public paths:
|
|
||||||
|
|
||||||
1. Cookie presence — return 401 NOT_AUTHENTICATED if missing
|
|
||||||
2. JWT validation via ``get_optional_user_from_request`` — return 401
|
|
||||||
TOKEN_INVALID if the token is absent, malformed, expired, or the
|
|
||||||
signed user does not exist / is stale
|
|
||||||
|
|
||||||
On success, stamps ``request.state.user`` and the
|
|
||||||
``deerflow.runtime.user_context`` contextvar so that repository-layer
|
|
||||||
owner filters work downstream without every route needing a
|
|
||||||
``@require_auth`` decorator. Routes that need per-resource
|
|
||||||
authorization (e.g. "user A cannot read user B's thread by guessing
|
|
||||||
the URL") should additionally use ``@require_permission(...,
|
|
||||||
owner_check=True)`` for explicit enforcement — but authentication
|
|
||||||
itself is fully handled here.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, app: ASGIApp) -> None:
|
|
||||||
super().__init__(app)
|
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
|
||||||
if _is_public(request.url.path):
|
|
||||||
return await call_next(request)
|
|
||||||
|
|
||||||
# Non-public path: require session cookie
|
|
||||||
if not request.cookies.get("access_token"):
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=401,
|
|
||||||
content={
|
|
||||||
"detail": AuthErrorResponse(
|
|
||||||
code=AuthErrorCode.NOT_AUTHENTICATED,
|
|
||||||
message="Authentication required",
|
|
||||||
).model_dump()
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Strict JWT validation: reject junk/expired tokens with 401
|
|
||||||
# right here instead of silently passing through. This closes
|
|
||||||
# the "junk cookie bypass" gap (AUTH_TEST_PLAN test 7.5.8):
|
|
||||||
# without this, non-isolation routes like /api/models would
|
|
||||||
# accept any cookie-shaped string as authentication.
|
|
||||||
#
|
|
||||||
# We call the *strict* resolver so that fine-grained error
|
|
||||||
# codes (token_expired, token_invalid, user_not_found, …)
|
|
||||||
# propagate from AuthErrorCode, not get flattened into one
|
|
||||||
# generic code. BaseHTTPMiddleware doesn't let HTTPException
|
|
||||||
# bubble up, so we catch and render it as JSONResponse here.
|
|
||||||
from app.gateway.deps import get_current_user_from_request
|
|
||||||
|
|
||||||
try:
|
|
||||||
user = await get_current_user_from_request(request)
|
|
||||||
except HTTPException as exc:
|
|
||||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
|
||||||
|
|
||||||
# Stamp both request.state.user (for the contextvar pattern)
|
|
||||||
# and request.state.auth (so @require_permission's "auth is
|
|
||||||
# None" branch short-circuits instead of running the entire
|
|
||||||
# JWT-decode + DB-lookup pipeline a second time per request).
|
|
||||||
request.state.user = user
|
|
||||||
request.state.auth = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
|
||||||
token = set_current_user(user)
|
|
||||||
try:
|
|
||||||
return await call_next(request)
|
|
||||||
finally:
|
|
||||||
reset_current_user(token)
|
|
||||||
@@ -1,262 +0,0 @@
|
|||||||
"""Authorization decorators and context for DeerFlow.
|
|
||||||
|
|
||||||
Inspired by LangGraph Auth system: https://github.com/langchain-ai/langgraph/blob/main/libs/sdk-py/langgraph_sdk/auth/__init__.py
|
|
||||||
|
|
||||||
**Usage:**
|
|
||||||
|
|
||||||
1. Use ``@require_auth`` on routes that need authentication
|
|
||||||
2. Use ``@require_permission("resource", "action", filter_key=...)`` for permission checks
|
|
||||||
3. The decorator chain processes from bottom to top
|
|
||||||
|
|
||||||
**Example:**
|
|
||||||
|
|
||||||
@router.get("/{thread_id}")
|
|
||||||
@require_auth
|
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def get_thread(thread_id: str, request: Request):
|
|
||||||
# User is authenticated and has threads:read permission
|
|
||||||
...
|
|
||||||
|
|
||||||
**Permission Model:**
|
|
||||||
|
|
||||||
- threads:read - View thread
|
|
||||||
- threads:write - Create/update thread
|
|
||||||
- threads:delete - Delete thread
|
|
||||||
- runs:create - Run agent
|
|
||||||
- runs:read - View run
|
|
||||||
- runs:cancel - Cancel run
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import functools
|
|
||||||
from collections.abc import Callable
|
|
||||||
from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar
|
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from app.gateway.auth.models import User
|
|
||||||
|
|
||||||
P = ParamSpec("P")
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
# Permission constants
|
|
||||||
class Permissions:
|
|
||||||
"""Permission constants for resource:action format."""
|
|
||||||
|
|
||||||
# Threads
|
|
||||||
THREADS_READ = "threads:read"
|
|
||||||
THREADS_WRITE = "threads:write"
|
|
||||||
THREADS_DELETE = "threads:delete"
|
|
||||||
|
|
||||||
# Runs
|
|
||||||
RUNS_CREATE = "runs:create"
|
|
||||||
RUNS_READ = "runs:read"
|
|
||||||
RUNS_CANCEL = "runs:cancel"
|
|
||||||
|
|
||||||
|
|
||||||
class AuthContext:
|
|
||||||
"""Authentication context for the current request.
|
|
||||||
|
|
||||||
Stored in request.state.auth after require_auth decoration.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
user: The authenticated user, or None if anonymous
|
|
||||||
permissions: List of permission strings (e.g., "threads:read")
|
|
||||||
"""
|
|
||||||
|
|
||||||
__slots__ = ("user", "permissions")
|
|
||||||
|
|
||||||
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
|
|
||||||
self.user = user
|
|
||||||
self.permissions = permissions or []
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_authenticated(self) -> bool:
|
|
||||||
"""Check if user is authenticated."""
|
|
||||||
return self.user is not None
|
|
||||||
|
|
||||||
def has_permission(self, resource: str, action: str) -> bool:
|
|
||||||
"""Check if context has permission for resource:action.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
resource: Resource name (e.g., "threads")
|
|
||||||
action: Action name (e.g., "read")
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if user has permission
|
|
||||||
"""
|
|
||||||
permission = f"{resource}:{action}"
|
|
||||||
return permission in self.permissions
|
|
||||||
|
|
||||||
def require_user(self) -> User:
|
|
||||||
"""Get user or raise 401.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException 401 if not authenticated
|
|
||||||
"""
|
|
||||||
if not self.user:
|
|
||||||
raise HTTPException(status_code=401, detail="Authentication required")
|
|
||||||
return self.user
|
|
||||||
|
|
||||||
|
|
||||||
def get_auth_context(request: Request) -> AuthContext | None:
|
|
||||||
"""Get AuthContext from request state."""
|
|
||||||
return getattr(request.state, "auth", None)
|
|
||||||
|
|
||||||
|
|
||||||
_ALL_PERMISSIONS: list[str] = [
|
|
||||||
Permissions.THREADS_READ,
|
|
||||||
Permissions.THREADS_WRITE,
|
|
||||||
Permissions.THREADS_DELETE,
|
|
||||||
Permissions.RUNS_CREATE,
|
|
||||||
Permissions.RUNS_READ,
|
|
||||||
Permissions.RUNS_CANCEL,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def _authenticate(request: Request) -> AuthContext:
|
|
||||||
"""Authenticate request and return AuthContext.
|
|
||||||
|
|
||||||
Delegates to deps.get_optional_user_from_request() for the JWT→User pipeline.
|
|
||||||
Returns AuthContext with user=None for anonymous requests.
|
|
||||||
"""
|
|
||||||
from app.gateway.deps import get_optional_user_from_request
|
|
||||||
|
|
||||||
user = await get_optional_user_from_request(request)
|
|
||||||
if user is None:
|
|
||||||
return AuthContext(user=None, permissions=[])
|
|
||||||
|
|
||||||
# In future, permissions could be stored in user record
|
|
||||||
return AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
|
||||||
|
|
||||||
|
|
||||||
def require_auth[**P, T](func: Callable[P, T]) -> Callable[P, T]:
|
|
||||||
"""Decorator that authenticates the request and sets AuthContext.
|
|
||||||
|
|
||||||
Must be placed ABOVE other decorators (executes after them).
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
@router.get("/{thread_id}")
|
|
||||||
@require_auth # Bottom decorator (executes first after permission check)
|
|
||||||
@require_permission("threads", "read")
|
|
||||||
async def get_thread(thread_id: str, request: Request):
|
|
||||||
auth: AuthContext = request.state.auth
|
|
||||||
...
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If 'request' parameter is missing
|
|
||||||
"""
|
|
||||||
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
||||||
request = kwargs.get("request")
|
|
||||||
if request is None:
|
|
||||||
raise ValueError("require_auth decorator requires 'request' parameter")
|
|
||||||
|
|
||||||
# Authenticate and set context
|
|
||||||
auth_context = await _authenticate(request)
|
|
||||||
request.state.auth = auth_context
|
|
||||||
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
|
|
||||||
def require_permission(
|
|
||||||
resource: str,
|
|
||||||
action: str,
|
|
||||||
owner_check: bool = False,
|
|
||||||
require_existing: bool = False,
|
|
||||||
) -> Callable[[Callable[P, T]], Callable[P, T]]:
|
|
||||||
"""Decorator that checks permission for resource:action.
|
|
||||||
|
|
||||||
Must be used AFTER @require_auth.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
resource: Resource name (e.g., "threads", "runs")
|
|
||||||
action: Action name (e.g., "read", "write", "delete")
|
|
||||||
owner_check: If True, validates that the current user owns the resource.
|
|
||||||
Requires 'thread_id' path parameter and performs ownership check.
|
|
||||||
require_existing: Only meaningful with ``owner_check=True``. If True, a
|
|
||||||
missing ``threads_meta`` row counts as a denial (404)
|
|
||||||
instead of "untracked legacy thread, allow". Use on
|
|
||||||
**destructive / mutating** routes (DELETE, PATCH,
|
|
||||||
state-update) so a deleted thread can't be re-targeted
|
|
||||||
by another user via the missing-row code path.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
# Read-style: legacy untracked threads are allowed
|
|
||||||
@require_permission("threads", "read", owner_check=True)
|
|
||||||
async def get_thread(thread_id: str, request: Request):
|
|
||||||
...
|
|
||||||
|
|
||||||
# Destructive: thread row MUST exist and be owned by caller
|
|
||||||
@require_permission("threads", "delete", owner_check=True, require_existing=True)
|
|
||||||
async def delete_thread(thread_id: str, request: Request):
|
|
||||||
...
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
HTTPException 401: If authentication required but user is anonymous
|
|
||||||
HTTPException 403: If user lacks permission
|
|
||||||
HTTPException 404: If owner_check=True but user doesn't own the thread
|
|
||||||
ValueError: If owner_check=True but 'thread_id' parameter is missing
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(func: Callable[P, T]) -> Callable[P, T]:
|
|
||||||
@functools.wraps(func)
|
|
||||||
async def wrapper(*args: Any, **kwargs: Any) -> Any:
|
|
||||||
request = kwargs.get("request")
|
|
||||||
if request is None:
|
|
||||||
raise ValueError("require_permission decorator requires 'request' parameter")
|
|
||||||
|
|
||||||
auth: AuthContext = getattr(request.state, "auth", None)
|
|
||||||
if auth is None:
|
|
||||||
auth = await _authenticate(request)
|
|
||||||
request.state.auth = auth
|
|
||||||
|
|
||||||
if not auth.is_authenticated:
|
|
||||||
raise HTTPException(status_code=401, detail="Authentication required")
|
|
||||||
|
|
||||||
# Check permission
|
|
||||||
if not auth.has_permission(resource, action):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail=f"Permission denied: {resource}:{action}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Owner check for thread-specific resources.
|
|
||||||
#
|
|
||||||
# 2.0-rc moved thread metadata into the SQL persistence layer
|
|
||||||
# (``threads_meta`` table). We verify ownership via
|
|
||||||
# ``ThreadMetaStore.check_access``: it returns True for
|
|
||||||
# missing rows (untracked legacy thread) and for rows whose
|
|
||||||
# ``user_id`` is NULL (shared / pre-auth data), so this is
|
|
||||||
# strict-deny rather than strict-allow — only an *existing*
|
|
||||||
# row with a *different* user_id triggers 404.
|
|
||||||
if owner_check:
|
|
||||||
thread_id = kwargs.get("thread_id")
|
|
||||||
if thread_id is None:
|
|
||||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
|
||||||
|
|
||||||
from app.gateway.deps import get_thread_store
|
|
||||||
|
|
||||||
thread_store = get_thread_store(request)
|
|
||||||
allowed = await thread_store.check_access(
|
|
||||||
thread_id,
|
|
||||||
str(auth.user.id),
|
|
||||||
require_existing=require_existing,
|
|
||||||
)
|
|
||||||
if not allowed:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=404,
|
|
||||||
detail=f"Thread {thread_id} not found",
|
|
||||||
)
|
|
||||||
|
|
||||||
return await func(*args, **kwargs)
|
|
||||||
|
|
||||||
return wrapper
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
@@ -1,113 +0,0 @@
|
|||||||
"""CSRF protection middleware for FastAPI.
|
|
||||||
|
|
||||||
Per RFC-001:
|
|
||||||
State-changing operations require CSRF protection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import secrets
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
from fastapi import Request, Response
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.responses import JSONResponse
|
|
||||||
from starlette.types import ASGIApp
|
|
||||||
|
|
||||||
CSRF_COOKIE_NAME = "csrf_token"
|
|
||||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
|
||||||
CSRF_TOKEN_LENGTH = 64 # bytes
|
|
||||||
|
|
||||||
|
|
||||||
def is_secure_request(request: Request) -> bool:
|
|
||||||
"""Detect whether the original client request was made over HTTPS."""
|
|
||||||
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
|
|
||||||
|
|
||||||
|
|
||||||
def generate_csrf_token() -> str:
|
|
||||||
"""Generate a secure random CSRF token."""
|
|
||||||
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
|
||||||
|
|
||||||
|
|
||||||
def should_check_csrf(request: Request) -> bool:
|
|
||||||
"""Determine if a request needs CSRF validation.
|
|
||||||
|
|
||||||
CSRF is checked for state-changing methods (POST, PUT, DELETE, PATCH).
|
|
||||||
GET, HEAD, OPTIONS, and TRACE are exempt per RFC 7231.
|
|
||||||
"""
|
|
||||||
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
|
||||||
return False
|
|
||||||
|
|
||||||
path = request.url.path.rstrip("/")
|
|
||||||
# Exempt /api/v1/auth/me endpoint
|
|
||||||
if path == "/api/v1/auth/me":
|
|
||||||
return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
|
|
||||||
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
|
|
||||||
{
|
|
||||||
"/api/v1/auth/login/local",
|
|
||||||
"/api/v1/auth/logout",
|
|
||||||
"/api/v1/auth/register",
|
|
||||||
"/api/v1/auth/initialize",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def is_auth_endpoint(request: Request) -> bool:
|
|
||||||
"""Check if the request is to an auth endpoint.
|
|
||||||
|
|
||||||
Auth endpoints don't need CSRF validation on first call (no token).
|
|
||||||
"""
|
|
||||||
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
|
||||||
|
|
||||||
|
|
||||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""Middleware that implements CSRF protection using Double Submit Cookie pattern."""
|
|
||||||
|
|
||||||
def __init__(self, app: ASGIApp) -> None:
|
|
||||||
super().__init__(app)
|
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
|
||||||
_is_auth = is_auth_endpoint(request)
|
|
||||||
|
|
||||||
if should_check_csrf(request) and not _is_auth:
|
|
||||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
|
||||||
header_token = request.headers.get(CSRF_HEADER_NAME)
|
|
||||||
|
|
||||||
if not cookie_token or not header_token:
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=403,
|
|
||||||
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
|
|
||||||
)
|
|
||||||
|
|
||||||
if not secrets.compare_digest(cookie_token, header_token):
|
|
||||||
return JSONResponse(
|
|
||||||
status_code=403,
|
|
||||||
content={"detail": "CSRF token mismatch."},
|
|
||||||
)
|
|
||||||
|
|
||||||
response = await call_next(request)
|
|
||||||
|
|
||||||
# For auth endpoints that set up session, also set CSRF cookie
|
|
||||||
if _is_auth and request.method == "POST":
|
|
||||||
# Generate a new CSRF token for the session
|
|
||||||
csrf_token = generate_csrf_token()
|
|
||||||
is_https = is_secure_request(request)
|
|
||||||
response.set_cookie(
|
|
||||||
key=CSRF_COOKIE_NAME,
|
|
||||||
value=csrf_token,
|
|
||||||
httponly=False, # Must be JS-readable for Double Submit Cookie pattern
|
|
||||||
secure=is_https,
|
|
||||||
samesite="strict",
|
|
||||||
)
|
|
||||||
|
|
||||||
return response
|
|
||||||
|
|
||||||
|
|
||||||
def get_csrf_token(request: Request) -> str | None:
|
|
||||||
"""Get the CSRF token from the current request's cookies.
|
|
||||||
|
|
||||||
This is useful for server-side rendering where you need to embed
|
|
||||||
token in forms or headers.
|
|
||||||
"""
|
|
||||||
return request.cookies.get(CSRF_COOKIE_NAME)
|
|
||||||
@@ -1,234 +0,0 @@
|
|||||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
|
||||||
|
|
||||||
**Getters** (used by routers): raise 503 when a required dependency is
|
|
||||||
missing, except ``get_store`` which returns ``None``.
|
|
||||||
|
|
||||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator, Callable
|
|
||||||
from contextlib import AsyncExitStack, asynccontextmanager
|
|
||||||
from typing import TYPE_CHECKING, TypeVar, cast
|
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
|
||||||
from langgraph.types import Checkpointer
|
|
||||||
|
|
||||||
from deerflow.persistence.feedback import FeedbackRepository
|
|
||||||
from deerflow.runtime import RunContext, RunManager, StreamBridge
|
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
|
||||||
from deerflow.runtime.runs.store.base import RunStore
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
|
||||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
|
||||||
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
||||||
"""Bootstrap and tear down all LangGraph runtime singletons.
|
|
||||||
|
|
||||||
Usage in ``app.py``::
|
|
||||||
|
|
||||||
async with langgraph_runtime(app):
|
|
||||||
yield
|
|
||||||
"""
|
|
||||||
from deerflow.config import get_app_config
|
|
||||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
|
||||||
from deerflow.runtime import make_store, make_stream_bridge
|
|
||||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
|
||||||
from deerflow.runtime.events.store import make_run_event_store
|
|
||||||
|
|
||||||
async with AsyncExitStack() as stack:
|
|
||||||
app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge())
|
|
||||||
|
|
||||||
# Initialize persistence engine BEFORE checkpointer so that
|
|
||||||
# auto-create-database logic runs first (postgres backend).
|
|
||||||
config = get_app_config()
|
|
||||||
await init_engine_from_config(config.database)
|
|
||||||
|
|
||||||
app.state.checkpointer = await stack.enter_async_context(make_checkpointer())
|
|
||||||
app.state.store = await stack.enter_async_context(make_store())
|
|
||||||
|
|
||||||
# Initialize repositories — one get_session_factory() call for all.
|
|
||||||
sf = get_session_factory()
|
|
||||||
if sf is not None:
|
|
||||||
from deerflow.persistence.feedback import FeedbackRepository
|
|
||||||
from deerflow.persistence.run import RunRepository
|
|
||||||
|
|
||||||
app.state.run_store = RunRepository(sf)
|
|
||||||
app.state.feedback_repo = FeedbackRepository(sf)
|
|
||||||
else:
|
|
||||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
|
||||||
|
|
||||||
app.state.run_store = MemoryRunStore()
|
|
||||||
app.state.feedback_repo = None
|
|
||||||
|
|
||||||
from deerflow.persistence.thread_meta import make_thread_store
|
|
||||||
|
|
||||||
app.state.thread_store = make_thread_store(sf, app.state.store)
|
|
||||||
|
|
||||||
# Run event store (has its own factory with config-driven backend selection)
|
|
||||||
run_events_config = getattr(config, "run_events", None)
|
|
||||||
app.state.run_event_store = make_run_event_store(run_events_config)
|
|
||||||
|
|
||||||
# RunManager with store backing for persistence
|
|
||||||
app.state.run_manager = RunManager(store=app.state.run_store)
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
await close_engine()
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Getters – called by routers per-request
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _require(attr: str, label: str) -> Callable[[Request], T]:
|
|
||||||
"""Create a FastAPI dependency that returns ``app.state.<attr>`` or 503."""
|
|
||||||
|
|
||||||
def dep(request: Request) -> T:
|
|
||||||
val = getattr(request.app.state, attr, None)
|
|
||||||
if val is None:
|
|
||||||
raise HTTPException(status_code=503, detail=f"{label} not available")
|
|
||||||
return cast(T, val)
|
|
||||||
|
|
||||||
dep.__name__ = dep.__qualname__ = f"get_{attr}"
|
|
||||||
return dep
|
|
||||||
|
|
||||||
|
|
||||||
get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge")
|
|
||||||
get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager")
|
|
||||||
get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer")
|
|
||||||
get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store")
|
|
||||||
get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback")
|
|
||||||
get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store")
|
|
||||||
|
|
||||||
|
|
||||||
def get_store(request: Request):
|
|
||||||
"""Return the global store (may be ``None`` if not configured)."""
|
|
||||||
return getattr(request.app.state, "store", None)
|
|
||||||
|
|
||||||
|
|
||||||
def get_thread_store(request: Request) -> ThreadMetaStore:
|
|
||||||
"""Return the thread metadata store (SQL or memory-backed)."""
|
|
||||||
val = getattr(request.app.state, "thread_store", None)
|
|
||||||
if val is None:
|
|
||||||
raise HTTPException(status_code=503, detail="Thread metadata store not available")
|
|
||||||
return val
|
|
||||||
|
|
||||||
|
|
||||||
def get_run_context(request: Request) -> RunContext:
|
|
||||||
"""Build a :class:`RunContext` from ``app.state`` singletons.
|
|
||||||
|
|
||||||
Returns a *base* context with infrastructure dependencies.
|
|
||||||
"""
|
|
||||||
from deerflow.config import get_app_config
|
|
||||||
|
|
||||||
return RunContext(
|
|
||||||
checkpointer=get_checkpointer(request),
|
|
||||||
store=get_store(request),
|
|
||||||
event_store=get_run_event_store(request),
|
|
||||||
run_events_config=getattr(get_app_config(), "run_events", None),
|
|
||||||
thread_store=get_thread_store(request),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Auth helpers (used by authz.py and auth middleware)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Cached singletons to avoid repeated instantiation per request
|
|
||||||
_cached_local_provider: LocalAuthProvider | None = None
|
|
||||||
_cached_repo: SQLiteUserRepository | None = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_local_provider() -> LocalAuthProvider:
|
|
||||||
"""Get or create the cached LocalAuthProvider singleton.
|
|
||||||
|
|
||||||
Must be called after ``init_engine_from_config()`` — the shared
|
|
||||||
session factory is required to construct the user repository.
|
|
||||||
"""
|
|
||||||
global _cached_local_provider, _cached_repo
|
|
||||||
if _cached_repo is None:
|
|
||||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
|
||||||
from deerflow.persistence.engine import get_session_factory
|
|
||||||
|
|
||||||
sf = get_session_factory()
|
|
||||||
if sf is None:
|
|
||||||
raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
|
|
||||||
_cached_repo = SQLiteUserRepository(sf)
|
|
||||||
if _cached_local_provider is None:
|
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
|
||||||
|
|
||||||
_cached_local_provider = LocalAuthProvider(repository=_cached_repo)
|
|
||||||
return _cached_local_provider
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user_from_request(request: Request):
|
|
||||||
"""Get the current authenticated user from the request cookie.
|
|
||||||
|
|
||||||
Raises HTTPException 401 if not authenticated.
|
|
||||||
"""
|
|
||||||
from app.gateway.auth import decode_token
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
|
||||||
|
|
||||||
access_token = request.cookies.get("access_token")
|
|
||||||
if not access_token:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = decode_token(access_token)
|
|
||||||
if isinstance(payload, TokenError):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
provider = get_local_provider()
|
|
||||||
user = await provider.get_user(payload.sub)
|
|
||||||
if user is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Token version mismatch → password was changed, token is stale
|
|
||||||
if user.token_version != payload.ver:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
return user
|
|
||||||
|
|
||||||
|
|
||||||
async def get_optional_user_from_request(request: Request):
|
|
||||||
"""Get optional authenticated user from request.
|
|
||||||
|
|
||||||
Returns None if not authenticated.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
return await get_current_user_from_request(request)
|
|
||||||
except HTTPException:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> str | None:
|
|
||||||
"""Extract user_id from request cookie, or None if not authenticated.
|
|
||||||
|
|
||||||
Thin adapter that returns the string id for callers that only need
|
|
||||||
identification (e.g., ``feedback.py``). Full-user callers should use
|
|
||||||
``get_current_user_from_request`` or ``get_optional_user_from_request``.
|
|
||||||
"""
|
|
||||||
user = await get_optional_user_from_request(request)
|
|
||||||
return str(user.id) if user else None
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
"""LangGraph Server auth handler — shares JWT logic with Gateway.
|
|
||||||
|
|
||||||
Loaded by LangGraph Server via langgraph.json ``auth.path``.
|
|
||||||
Reuses the same ``decode_token`` / ``get_auth_config`` as Gateway,
|
|
||||||
so both modes validate tokens with the same secret and rules.
|
|
||||||
|
|
||||||
Two layers:
|
|
||||||
1. @auth.authenticate — validates JWT cookie, extracts user_id,
|
|
||||||
and enforces CSRF on state-changing methods (POST/PUT/DELETE/PATCH)
|
|
||||||
2. @auth.on — returns metadata filter so each user only sees own threads
|
|
||||||
"""
|
|
||||||
|
|
||||||
import secrets
|
|
||||||
|
|
||||||
from langgraph_sdk import Auth
|
|
||||||
|
|
||||||
from app.gateway.auth.errors import TokenError
|
|
||||||
from app.gateway.auth.jwt import decode_token
|
|
||||||
from app.gateway.deps import get_local_provider
|
|
||||||
|
|
||||||
auth = Auth()
|
|
||||||
|
|
||||||
# Methods that require CSRF validation (state-changing per RFC 7231).
|
|
||||||
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
|
||||||
|
|
||||||
|
|
||||||
def _check_csrf(request) -> None:
|
|
||||||
"""Enforce Double Submit Cookie CSRF check for state-changing requests.
|
|
||||||
|
|
||||||
Mirrors Gateway's CSRFMiddleware logic so that LangGraph routes
|
|
||||||
proxied directly by nginx have the same CSRF protection.
|
|
||||||
"""
|
|
||||||
method = getattr(request, "method", "") or ""
|
|
||||||
if method.upper() not in _CSRF_METHODS:
|
|
||||||
return
|
|
||||||
|
|
||||||
cookie_token = request.cookies.get("csrf_token")
|
|
||||||
header_token = request.headers.get("x-csrf-token")
|
|
||||||
|
|
||||||
if not cookie_token or not header_token:
|
|
||||||
raise Auth.exceptions.HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail="CSRF token missing. Include X-CSRF-Token header.",
|
|
||||||
)
|
|
||||||
|
|
||||||
if not secrets.compare_digest(cookie_token, header_token):
|
|
||||||
raise Auth.exceptions.HTTPException(
|
|
||||||
status_code=403,
|
|
||||||
detail="CSRF token mismatch.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@auth.authenticate
|
|
||||||
async def authenticate(request):
|
|
||||||
"""Validate the session cookie, decode JWT, and check token_version.
|
|
||||||
|
|
||||||
Same validation chain as Gateway's get_current_user_from_request:
|
|
||||||
cookie → decode JWT → DB lookup → token_version match
|
|
||||||
Also enforces CSRF on state-changing methods.
|
|
||||||
"""
|
|
||||||
# CSRF check before authentication so forged cross-site requests
|
|
||||||
# are rejected early, even if the cookie carries a valid JWT.
|
|
||||||
_check_csrf(request)
|
|
||||||
|
|
||||||
token = request.cookies.get("access_token")
|
|
||||||
if not token:
|
|
||||||
raise Auth.exceptions.HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail="Not authenticated",
|
|
||||||
)
|
|
||||||
|
|
||||||
payload = decode_token(token)
|
|
||||||
if isinstance(payload, TokenError):
|
|
||||||
raise Auth.exceptions.HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail=f"Token error: {payload.value}",
|
|
||||||
)
|
|
||||||
|
|
||||||
user = await get_local_provider().get_user(payload.sub)
|
|
||||||
if user is None:
|
|
||||||
raise Auth.exceptions.HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail="User not found",
|
|
||||||
)
|
|
||||||
if user.token_version != payload.ver:
|
|
||||||
raise Auth.exceptions.HTTPException(
|
|
||||||
status_code=401,
|
|
||||||
detail="Token revoked (password changed)",
|
|
||||||
)
|
|
||||||
|
|
||||||
return payload.sub
|
|
||||||
|
|
||||||
|
|
||||||
@auth.on
|
|
||||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
|
||||||
"""Inject user_id metadata on writes; filter by user_id on reads.
|
|
||||||
|
|
||||||
Gateway stores thread ownership as ``metadata.user_id``.
|
|
||||||
This handler ensures LangGraph Server enforces the same isolation.
|
|
||||||
"""
|
|
||||||
# On create/update: stamp user_id into metadata
|
|
||||||
metadata = value.setdefault("metadata", {})
|
|
||||||
metadata["user_id"] = ctx.user.identity
|
|
||||||
|
|
||||||
# Return filter dict — LangGraph applies it to search/read/delete
|
|
||||||
return {"user_id": ctx.user.identity}
|
|
||||||
@@ -1,149 +0,0 @@
|
|||||||
"""Assistants compatibility endpoints.
|
|
||||||
|
|
||||||
Provides LangGraph Platform-compatible assistants API backed by the
|
|
||||||
``langgraph.json`` graph registry and ``config.yaml`` agent definitions.
|
|
||||||
|
|
||||||
This is a minimal stub that satisfies the ``useStream`` React hook's
|
|
||||||
initialization requirements (``assistants.search()`` and ``assistants.get()``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import UTC, datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
router = APIRouter(prefix="/api/assistants", tags=["assistants-compat"])
|
|
||||||
|
|
||||||
|
|
||||||
class AssistantResponse(BaseModel):
|
|
||||||
assistant_id: str
|
|
||||||
graph_id: str
|
|
||||||
name: str
|
|
||||||
config: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
description: str | None = None
|
|
||||||
created_at: str = ""
|
|
||||||
updated_at: str = ""
|
|
||||||
version: int = 1
|
|
||||||
|
|
||||||
|
|
||||||
class AssistantSearchRequest(BaseModel):
|
|
||||||
graph_id: str | None = None
|
|
||||||
name: str | None = None
|
|
||||||
metadata: dict[str, Any] | None = None
|
|
||||||
limit: int = 10
|
|
||||||
offset: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
def _get_default_assistant() -> AssistantResponse:
|
|
||||||
"""Return the default lead_agent assistant."""
|
|
||||||
now = datetime.now(UTC).isoformat()
|
|
||||||
return AssistantResponse(
|
|
||||||
assistant_id="lead_agent",
|
|
||||||
graph_id="lead_agent",
|
|
||||||
name="lead_agent",
|
|
||||||
config={},
|
|
||||||
metadata={"created_by": "system"},
|
|
||||||
description="DeerFlow lead agent",
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
version=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _list_assistants() -> list[AssistantResponse]:
|
|
||||||
"""List all available assistants from config."""
|
|
||||||
assistants = [_get_default_assistant()]
|
|
||||||
|
|
||||||
# Also include custom agents from config.yaml agents directory
|
|
||||||
try:
|
|
||||||
from deerflow.config.agents_config import list_custom_agents
|
|
||||||
|
|
||||||
for agent_cfg in list_custom_agents():
|
|
||||||
now = datetime.now(UTC).isoformat()
|
|
||||||
assistants.append(
|
|
||||||
AssistantResponse(
|
|
||||||
assistant_id=agent_cfg.name,
|
|
||||||
graph_id="lead_agent", # All agents use the same graph
|
|
||||||
name=agent_cfg.name,
|
|
||||||
config={},
|
|
||||||
metadata={"created_by": "user"},
|
|
||||||
description=agent_cfg.description or "",
|
|
||||||
created_at=now,
|
|
||||||
updated_at=now,
|
|
||||||
version=1,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.debug("Could not load custom agents for assistants list")
|
|
||||||
|
|
||||||
return assistants
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/search", response_model=list[AssistantResponse])
|
|
||||||
async def search_assistants(body: AssistantSearchRequest | None = None) -> list[AssistantResponse]:
|
|
||||||
"""Search assistants.
|
|
||||||
|
|
||||||
Returns all registered assistants (lead_agent + custom agents from config).
|
|
||||||
"""
|
|
||||||
assistants = _list_assistants()
|
|
||||||
|
|
||||||
if body and body.graph_id:
|
|
||||||
assistants = [a for a in assistants if a.graph_id == body.graph_id]
|
|
||||||
if body and body.name:
|
|
||||||
assistants = [a for a in assistants if body.name.lower() in a.name.lower()]
|
|
||||||
|
|
||||||
offset = body.offset if body else 0
|
|
||||||
limit = body.limit if body else 10
|
|
||||||
return assistants[offset : offset + limit]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{assistant_id}", response_model=AssistantResponse)
|
|
||||||
async def get_assistant_compat(assistant_id: str) -> AssistantResponse:
|
|
||||||
"""Get an assistant by ID."""
|
|
||||||
for a in _list_assistants():
|
|
||||||
if a.assistant_id == assistant_id:
|
|
||||||
return a
|
|
||||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{assistant_id}/graph")
|
|
||||||
async def get_assistant_graph(assistant_id: str) -> dict:
|
|
||||||
"""Get the graph structure for an assistant.
|
|
||||||
|
|
||||||
Returns a minimal graph description. Full graph introspection is
|
|
||||||
not supported in the Gateway — this stub satisfies SDK validation.
|
|
||||||
"""
|
|
||||||
found = any(a.assistant_id == assistant_id for a in _list_assistants())
|
|
||||||
if not found:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"graph_id": "lead_agent",
|
|
||||||
"nodes": [],
|
|
||||||
"edges": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{assistant_id}/schemas")
|
|
||||||
async def get_assistant_schemas(assistant_id: str) -> dict:
|
|
||||||
"""Get JSON schemas for an assistant's input/output/state.
|
|
||||||
|
|
||||||
Returns empty schemas — full introspection not supported in Gateway.
|
|
||||||
"""
|
|
||||||
found = any(a.assistant_id == assistant_id for a in _list_assistants())
|
|
||||||
if not found:
|
|
||||||
raise HTTPException(status_code=404, detail=f"Assistant {assistant_id} not found")
|
|
||||||
|
|
||||||
return {
|
|
||||||
"graph_id": "lead_agent",
|
|
||||||
"input_schema": {},
|
|
||||||
"output_schema": {},
|
|
||||||
"state_schema": {},
|
|
||||||
"config_schema": {},
|
|
||||||
}
|
|
||||||
@@ -1,459 +0,0 @@
|
|||||||
"""Authentication endpoints."""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from ipaddress import ip_address, ip_network
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
|
||||||
from pydantic import BaseModel, EmailStr, Field, field_validator
|
|
||||||
|
|
||||||
from app.gateway.auth import (
|
|
||||||
UserResponse,
|
|
||||||
create_access_token,
|
|
||||||
)
|
|
||||||
from app.gateway.auth.config import get_auth_config
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse
|
|
||||||
from app.gateway.csrf_middleware import is_secure_request
|
|
||||||
from app.gateway.deps import get_current_user_from_request, get_local_provider
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request/Response Models ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class LoginResponse(BaseModel):
|
|
||||||
"""Response model for login — token only lives in HttpOnly cookie."""
|
|
||||||
|
|
||||||
expires_in: int # seconds
|
|
||||||
needs_setup: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
# Top common-password blocklist. Drawn from the public SecLists "10k worst
|
|
||||||
# passwords" set, lowercased + length>=8 only (shorter ones already fail
|
|
||||||
# the min_length check). Kept tight on purpose: this is the **lower bound**
|
|
||||||
# defense, not a full HIBP / passlib check, and runs in-process per request.
|
|
||||||
_COMMON_PASSWORDS: frozenset[str] = frozenset(
|
|
||||||
{
|
|
||||||
"password",
|
|
||||||
"password1",
|
|
||||||
"password12",
|
|
||||||
"password123",
|
|
||||||
"password1234",
|
|
||||||
"12345678",
|
|
||||||
"123456789",
|
|
||||||
"1234567890",
|
|
||||||
"qwerty12",
|
|
||||||
"qwertyui",
|
|
||||||
"qwerty123",
|
|
||||||
"abc12345",
|
|
||||||
"abcd1234",
|
|
||||||
"iloveyou",
|
|
||||||
"letmein1",
|
|
||||||
"welcome1",
|
|
||||||
"welcome123",
|
|
||||||
"admin123",
|
|
||||||
"administrator",
|
|
||||||
"passw0rd",
|
|
||||||
"p@ssw0rd",
|
|
||||||
"monkey12",
|
|
||||||
"trustno1",
|
|
||||||
"sunshine",
|
|
||||||
"princess",
|
|
||||||
"football",
|
|
||||||
"baseball",
|
|
||||||
"superman",
|
|
||||||
"batman123",
|
|
||||||
"starwars",
|
|
||||||
"dragon123",
|
|
||||||
"master123",
|
|
||||||
"shadow12",
|
|
||||||
"michael1",
|
|
||||||
"jennifer",
|
|
||||||
"computer",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _password_is_common(password: str) -> bool:
|
|
||||||
"""Case-insensitive blocklist check.
|
|
||||||
|
|
||||||
Lowercases the input so trivial mutations like ``Password`` /
|
|
||||||
``PASSWORD`` are also rejected. Does not normalize digit substitutions
|
|
||||||
(``p@ssw0rd`` is included as a literal entry instead) — keeping the
|
|
||||||
rule cheap and predictable.
|
|
||||||
"""
|
|
||||||
return password.lower() in _COMMON_PASSWORDS
|
|
||||||
|
|
||||||
|
|
||||||
def _validate_strong_password(value: str) -> str:
|
|
||||||
"""Pydantic field-validator body shared by Register + ChangePassword.
|
|
||||||
|
|
||||||
Constraint = function, not type-level mixin. The two request models
|
|
||||||
have no "is-a" relationship; they only share the password-strength
|
|
||||||
rule. Lifting it into a free function lets each model bind it via
|
|
||||||
``@field_validator(field_name)`` without inheritance gymnastics.
|
|
||||||
"""
|
|
||||||
if _password_is_common(value):
|
|
||||||
raise ValueError("Password is too common; choose a stronger password.")
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class RegisterRequest(BaseModel):
|
|
||||||
"""Request model for user registration."""
|
|
||||||
|
|
||||||
email: EmailStr
|
|
||||||
password: str = Field(..., min_length=8)
|
|
||||||
|
|
||||||
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
|
||||||
|
|
||||||
|
|
||||||
class ChangePasswordRequest(BaseModel):
|
|
||||||
"""Request model for password change (also handles setup flow)."""
|
|
||||||
|
|
||||||
current_password: str
|
|
||||||
new_password: str = Field(..., min_length=8)
|
|
||||||
new_email: EmailStr | None = None
|
|
||||||
|
|
||||||
_strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
|
||||||
|
|
||||||
|
|
||||||
class MessageResponse(BaseModel):
|
|
||||||
"""Generic message response."""
|
|
||||||
|
|
||||||
message: str
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ───────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
|
|
||||||
"""Set the access_token HttpOnly cookie on the response."""
|
|
||||||
config = get_auth_config()
|
|
||||||
is_https = is_secure_request(request)
|
|
||||||
response.set_cookie(
|
|
||||||
key="access_token",
|
|
||||||
value=token,
|
|
||||||
httponly=True,
|
|
||||||
secure=is_https,
|
|
||||||
samesite="lax",
|
|
||||||
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Rate Limiting ────────────────────────────────────────────────────────
|
|
||||||
# In-process dict — not shared across workers. Sufficient for single-worker deployments.
|
|
||||||
|
|
||||||
_MAX_LOGIN_ATTEMPTS = 5
|
|
||||||
_LOCKOUT_SECONDS = 300 # 5 minutes
|
|
||||||
|
|
||||||
# ip → (fail_count, lock_until_timestamp)
|
|
||||||
_login_attempts: dict[str, tuple[int, float]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def _trusted_proxies() -> list:
|
|
||||||
"""Parse ``AUTH_TRUSTED_PROXIES`` env var into a list of ip_network objects.
|
|
||||||
|
|
||||||
Comma-separated CIDR or single-IP entries. Empty / unset = no proxy is
|
|
||||||
trusted (direct mode). Invalid entries are skipped with a logger warning.
|
|
||||||
Read live so env-var overrides take effect immediately and tests can
|
|
||||||
``monkeypatch.setenv`` without poking a module-level cache.
|
|
||||||
"""
|
|
||||||
raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip()
|
|
||||||
if not raw:
|
|
||||||
return []
|
|
||||||
nets = []
|
|
||||||
for entry in raw.split(","):
|
|
||||||
entry = entry.strip()
|
|
||||||
if not entry:
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
nets.append(ip_network(entry, strict=False))
|
|
||||||
except ValueError:
|
|
||||||
logger.warning("AUTH_TRUSTED_PROXIES: ignoring invalid entry %r", entry)
|
|
||||||
return nets
|
|
||||||
|
|
||||||
|
|
||||||
def _get_client_ip(request: Request) -> str:
|
|
||||||
"""Extract the real client IP for rate limiting.
|
|
||||||
|
|
||||||
Trust model:
|
|
||||||
|
|
||||||
- The TCP peer (``request.client.host``) is always the baseline. It is
|
|
||||||
whatever the kernel reports as the connecting socket — unforgeable
|
|
||||||
by the client itself.
|
|
||||||
- ``X-Real-IP`` is **only** honored if the TCP peer is in the
|
|
||||||
``AUTH_TRUSTED_PROXIES`` allowlist (set via env var, comma-separated
|
|
||||||
CIDR or single IPs). When set, the gateway is assumed to be behind a
|
|
||||||
reverse proxy (nginx, Cloudflare, ALB, …) that overwrites
|
|
||||||
``X-Real-IP`` with the original client address.
|
|
||||||
- With no ``AUTH_TRUSTED_PROXIES`` set, ``X-Real-IP`` is silently
|
|
||||||
ignored — closing the bypass where any client could rotate the
|
|
||||||
header to dodge per-IP rate limits in dev / direct-gateway mode.
|
|
||||||
|
|
||||||
``X-Forwarded-For`` is intentionally NOT used because it is naturally
|
|
||||||
client-controlled at the *first* hop and the trust chain is harder to
|
|
||||||
audit per-request.
|
|
||||||
"""
|
|
||||||
peer_host = request.client.host if request.client else None
|
|
||||||
|
|
||||||
trusted = _trusted_proxies()
|
|
||||||
if trusted and peer_host:
|
|
||||||
try:
|
|
||||||
peer_ip = ip_address(peer_host)
|
|
||||||
if any(peer_ip in net for net in trusted):
|
|
||||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
|
||||||
if real_ip:
|
|
||||||
return real_ip
|
|
||||||
except ValueError:
|
|
||||||
# peer_host wasn't a parseable IP (e.g. "unknown") — fall through
|
|
||||||
pass
|
|
||||||
|
|
||||||
return peer_host or "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
def _check_rate_limit(ip: str) -> None:
|
|
||||||
"""Raise 429 if the IP is currently locked out."""
|
|
||||||
record = _login_attempts.get(ip)
|
|
||||||
if record is None:
|
|
||||||
return
|
|
||||||
fail_count, lock_until = record
|
|
||||||
if fail_count >= _MAX_LOGIN_ATTEMPTS:
|
|
||||||
if time.time() < lock_until:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=429,
|
|
||||||
detail="Too many login attempts. Try again later.",
|
|
||||||
)
|
|
||||||
del _login_attempts[ip]
|
|
||||||
|
|
||||||
|
|
||||||
_MAX_TRACKED_IPS = 10000
|
|
||||||
|
|
||||||
|
|
||||||
def _record_login_failure(ip: str) -> None:
|
|
||||||
"""Record a failed login attempt for the given IP."""
|
|
||||||
# Evict expired lockouts when dict grows too large
|
|
||||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
|
||||||
now = time.time()
|
|
||||||
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
|
|
||||||
for k in expired:
|
|
||||||
del _login_attempts[k]
|
|
||||||
# If still too large, evict cheapest-to-lose half: below-threshold
|
|
||||||
# IPs (lock_until=0.0) sort first, then earliest-expiring lockouts.
|
|
||||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
|
||||||
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
|
|
||||||
for k, _ in by_time[: len(by_time) // 2]:
|
|
||||||
del _login_attempts[k]
|
|
||||||
|
|
||||||
record = _login_attempts.get(ip)
|
|
||||||
if record is None:
|
|
||||||
_login_attempts[ip] = (1, 0.0)
|
|
||||||
else:
|
|
||||||
new_count = record[0] + 1
|
|
||||||
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
|
|
||||||
_login_attempts[ip] = (new_count, lock_until)
|
|
||||||
|
|
||||||
|
|
||||||
def _record_login_success(ip: str) -> None:
|
|
||||||
"""Clear failure counter for the given IP on successful login."""
|
|
||||||
_login_attempts.pop(ip, None)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Endpoints ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login/local", response_model=LoginResponse)
|
|
||||||
async def login_local(
|
|
||||||
request: Request,
|
|
||||||
response: Response,
|
|
||||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
|
||||||
):
|
|
||||||
"""Local email/password login."""
|
|
||||||
client_ip = _get_client_ip(request)
|
|
||||||
_check_rate_limit(client_ip)
|
|
||||||
|
|
||||||
user = await get_local_provider().authenticate({"email": form_data.username, "password": form_data.password})
|
|
||||||
|
|
||||||
if user is None:
|
|
||||||
_record_login_failure(client_ip)
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Incorrect email or password").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
_record_login_success(client_ip)
|
|
||||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
|
||||||
_set_session_cookie(response, token, request)
|
|
||||||
|
|
||||||
return LoginResponse(
|
|
||||||
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
|
|
||||||
needs_setup=user.needs_setup,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def register(request: Request, response: Response, body: RegisterRequest):
|
|
||||||
"""Register a new user account (always 'user' role).
|
|
||||||
|
|
||||||
Admin is auto-created on first boot. This endpoint creates regular users.
|
|
||||||
Auto-login by setting the session cookie.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="user")
|
|
||||||
except ValueError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already registered").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
|
||||||
_set_session_cookie(response, token, request)
|
|
||||||
|
|
||||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/logout", response_model=MessageResponse)
|
|
||||||
async def logout(request: Request, response: Response):
|
|
||||||
"""Logout current user by clearing the cookie."""
|
|
||||||
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
|
|
||||||
return MessageResponse(message="Successfully logged out")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/change-password", response_model=MessageResponse)
|
|
||||||
async def change_password(request: Request, response: Response, body: ChangePasswordRequest):
|
|
||||||
"""Change password for the currently authenticated user.
|
|
||||||
|
|
||||||
Also handles the first-boot setup flow:
|
|
||||||
- If new_email is provided, updates email (checks uniqueness)
|
|
||||||
- If user.needs_setup is True and new_email is given, clears needs_setup
|
|
||||||
- Always increments token_version to invalidate old sessions
|
|
||||||
- Re-issues session cookie with new token_version
|
|
||||||
"""
|
|
||||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
|
||||||
|
|
||||||
user = await get_current_user_from_request(request)
|
|
||||||
|
|
||||||
if user.password_hash is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="OAuth users cannot change password").model_dump())
|
|
||||||
|
|
||||||
if not await verify_password_async(body.current_password, user.password_hash):
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.INVALID_CREDENTIALS, message="Current password is incorrect").model_dump())
|
|
||||||
|
|
||||||
provider = get_local_provider()
|
|
||||||
|
|
||||||
# Update email if provided
|
|
||||||
if body.new_email is not None:
|
|
||||||
existing = await provider.get_user_by_email(body.new_email)
|
|
||||||
if existing and str(existing.id) != str(user.id):
|
|
||||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=AuthErrorResponse(code=AuthErrorCode.EMAIL_ALREADY_EXISTS, message="Email already in use").model_dump())
|
|
||||||
user.email = body.new_email
|
|
||||||
|
|
||||||
# Update password + bump version
|
|
||||||
user.password_hash = await hash_password_async(body.new_password)
|
|
||||||
user.token_version += 1
|
|
||||||
|
|
||||||
# Clear setup flag if this is the setup flow
|
|
||||||
if user.needs_setup and body.new_email is not None:
|
|
||||||
user.needs_setup = False
|
|
||||||
|
|
||||||
await provider.update_user(user)
|
|
||||||
|
|
||||||
# Re-issue cookie with new token_version
|
|
||||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
|
||||||
_set_session_cookie(response, token, request)
|
|
||||||
|
|
||||||
return MessageResponse(message="Password changed successfully")
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserResponse)
|
|
||||||
async def get_me(request: Request):
|
|
||||||
"""Get current authenticated user info."""
|
|
||||||
user = await get_current_user_from_request(request)
|
|
||||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/setup-status")
|
|
||||||
async def setup_status():
|
|
||||||
"""Check if an admin account exists. Returns needs_setup=True when no admin exists."""
|
|
||||||
admin_count = await get_local_provider().count_admin_users()
|
|
||||||
return {"needs_setup": admin_count == 0}
|
|
||||||
|
|
||||||
|
|
||||||
class InitializeAdminRequest(BaseModel):
|
|
||||||
"""Request model for first-boot admin account creation."""
|
|
||||||
|
|
||||||
email: EmailStr
|
|
||||||
password: str = Field(..., min_length=8)
|
|
||||||
|
|
||||||
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def initialize_admin(request: Request, response: Response, body: InitializeAdminRequest):
|
|
||||||
"""Create the first admin account on initial system setup.
|
|
||||||
|
|
||||||
Only callable when no admin exists. Returns 409 Conflict if an admin
|
|
||||||
already exists.
|
|
||||||
|
|
||||||
On success, the admin account is created with ``needs_setup=False`` and
|
|
||||||
the session cookie is set.
|
|
||||||
"""
|
|
||||||
admin_count = await get_local_provider().count_admin_users()
|
|
||||||
if admin_count > 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
user = await get_local_provider().create_user(email=body.email, password=body.password, system_role="admin", needs_setup=False)
|
|
||||||
except ValueError:
|
|
||||||
# DB unique-constraint race: another concurrent request beat us.
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail=AuthErrorResponse(code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED, message="System already initialized").model_dump(),
|
|
||||||
)
|
|
||||||
|
|
||||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
|
||||||
_set_session_cookie(response, token, request)
|
|
||||||
|
|
||||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
|
||||||
|
|
||||||
|
|
||||||
# ── OAuth Endpoints (Future/Placeholder) ─────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/oauth/{provider}")
|
|
||||||
async def oauth_login(provider: str):
|
|
||||||
"""Initiate OAuth login flow.
|
|
||||||
|
|
||||||
Redirects to the OAuth provider's authorization URL.
|
|
||||||
Currently a placeholder - requires OAuth provider implementation.
|
|
||||||
"""
|
|
||||||
if provider not in ["github", "google"]:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Unsupported OAuth provider: {provider}",
|
|
||||||
)
|
|
||||||
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
|
||||||
detail="OAuth login not yet implemented",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/callback/{provider}")
|
|
||||||
async def oauth_callback(provider: str, code: str, state: str):
|
|
||||||
"""OAuth callback endpoint.
|
|
||||||
|
|
||||||
Handles the OAuth provider's callback after user authorization.
|
|
||||||
Currently a placeholder.
|
|
||||||
"""
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_501_NOT_IMPLEMENTED,
|
|
||||||
detail="OAuth callback not yet implemented",
|
|
||||||
)
|
|
||||||
Reference in New Issue
Block a user