feat(app): add plugin system with auth plugin and static assets
Add new application structure: - app/main.py - application entry point - app/plugins/ - plugin system with auth plugin: - api/ - REST API endpoints and schemas - authorization/ - auth policies, providers, hooks - domain/ - business logic (service, models, jwt, password) - injection/ - route injection and guards - ops/ - operational utilities - runtime/ - runtime configuration - security/ - middleware, CSRF, dependencies - storage/ - user repositories and models - app/static/ - static assets (scalar.js for API docs) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,21 @@
|
||||
# Auth Plugin
|
||||
|
||||
This package is the future Level 2 auth plugin boundary for DeerFlow.
|
||||
|
||||
Scope:
|
||||
|
||||
- Auth domain logic: config, errors, models, JWT, password hashing, service
|
||||
- Auth adapters: HTTP router, FastAPI dependencies, middleware, LangGraph adapter
|
||||
- Auth storage: user/account models and repositories
|
||||
|
||||
Non-scope:
|
||||
|
||||
- Shared app/container bootstrap
|
||||
- Shared persistence engine/session lifecycle
|
||||
- Generic plugin discovery/registration framework
|
||||
|
||||
Target architecture:
|
||||
|
||||
- The plugin owns its storage definitions and business logic
|
||||
- The plugin reuses the application's shared persistence infrastructure
|
||||
- The gateway only assembles the plugin instead of owning auth logic directly
|
||||
@@ -0,0 +1,14 @@
|
||||
"""Auth plugin package.
|
||||
|
||||
Level 2 plugin goal:
|
||||
- Own auth domain logic
|
||||
- Own auth adapters (router, dependencies, middleware, LangGraph adapter)
|
||||
- Own auth storage definitions
|
||||
- Reuse the application's shared persistence/session infrastructure
|
||||
"""
|
||||
|
||||
from app.plugins.auth.authorization.hooks import build_authz_hooks
|
||||
|
||||
__all__ = [
|
||||
"build_authz_hooks",
|
||||
]
|
||||
@@ -0,0 +1,17 @@
|
||||
"""HTTP API layer for the auth plugin."""
|
||||
|
||||
from app.plugins.auth.api.router import (
|
||||
ChangePasswordRequest,
|
||||
LoginResponse,
|
||||
MessageResponse,
|
||||
RegisterRequest,
|
||||
router,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"ChangePasswordRequest",
|
||||
"LoginResponse",
|
||||
"MessageResponse",
|
||||
"RegisterRequest",
|
||||
"router",
|
||||
]
|
||||
@@ -0,0 +1,171 @@
|
||||
"""Authentication endpoints for the auth plugin."""
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request, Response, status
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
|
||||
from app.plugins.auth.api.schemas import (
|
||||
ChangePasswordRequest,
|
||||
InitializeAdminRequest,
|
||||
LoginResponse,
|
||||
MessageResponse,
|
||||
RegisterRequest,
|
||||
_check_rate_limit,
|
||||
_get_client_ip,
|
||||
_login_attempts,
|
||||
_record_login_failure,
|
||||
_record_login_success,
|
||||
)
|
||||
from app.plugins.auth.domain.errors import AuthErrorResponse
|
||||
from app.plugins.auth.domain.jwt import create_access_token
|
||||
from app.plugins.auth.domain.models import UserResponse
|
||||
from app.plugins.auth.domain.service import AuthServiceError
|
||||
from app.plugins.auth.runtime.config_state import get_auth_config
|
||||
from app.plugins.auth.security.csrf import is_secure_request
|
||||
from app.plugins.auth.security.dependencies import CurrentAuthService, get_current_user_from_request
|
||||
|
||||
router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
|
||||
def _set_session_cookie(response: Response, token: str, request: Request) -> None:
|
||||
config = get_auth_config()
|
||||
is_https = is_secure_request(request)
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=is_https,
|
||||
samesite="lax",
|
||||
max_age=config.token_expiry_days * 24 * 3600 if is_https else None,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login/local", response_model=LoginResponse)
|
||||
async def login_local(
|
||||
request: Request,
|
||||
response: Response,
|
||||
auth_service: CurrentAuthService,
|
||||
form_data: OAuth2PasswordRequestForm = Depends(),
|
||||
):
|
||||
client_ip = _get_client_ip(request)
|
||||
_check_rate_limit(client_ip)
|
||||
try:
|
||||
user = await auth_service.login_local(form_data.username, form_data.password)
|
||||
except AuthServiceError as exc:
|
||||
_record_login_failure(client_ip)
|
||||
raise HTTPException(
|
||||
status_code=exc.status_code,
|
||||
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
|
||||
) from exc
|
||||
|
||||
_record_login_success(client_ip)
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
return LoginResponse(
|
||||
expires_in=get_auth_config().token_expiry_days * 24 * 3600,
|
||||
needs_setup=user.needs_setup,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/register", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def register(request: Request, response: Response, body: RegisterRequest, auth_service: CurrentAuthService):
|
||||
try:
|
||||
user = await auth_service.register(body.email, body.password)
|
||||
except AuthServiceError as exc:
|
||||
raise HTTPException(
|
||||
status_code=exc.status_code,
|
||||
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
|
||||
) from exc
|
||||
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||
|
||||
|
||||
@router.post("/logout", response_model=MessageResponse)
|
||||
async def logout(request: Request, response: Response):
|
||||
response.delete_cookie(key="access_token", secure=is_secure_request(request), samesite="lax")
|
||||
return MessageResponse(message="Successfully logged out")
|
||||
|
||||
|
||||
@router.post("/change-password", response_model=MessageResponse)
|
||||
async def change_password(
|
||||
request: Request,
|
||||
response: Response,
|
||||
body: ChangePasswordRequest,
|
||||
auth_service: CurrentAuthService,
|
||||
):
|
||||
user = await get_current_user_from_request(request)
|
||||
try:
|
||||
user = await auth_service.change_password(
|
||||
user,
|
||||
current_password=body.current_password,
|
||||
new_password=body.new_password,
|
||||
new_email=body.new_email,
|
||||
)
|
||||
except AuthServiceError as exc:
|
||||
raise HTTPException(
|
||||
status_code=exc.status_code,
|
||||
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
|
||||
) from exc
|
||||
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
return MessageResponse(message="Password changed successfully")
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserResponse)
|
||||
async def get_me(request: Request):
|
||||
user = await get_current_user_from_request(request)
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
|
||||
|
||||
|
||||
@router.get("/setup-status")
|
||||
async def setup_status(auth_service: CurrentAuthService):
|
||||
return {"needs_setup": await auth_service.get_setup_status()}
|
||||
|
||||
|
||||
@router.post("/initialize", response_model=UserResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def initialize_admin(
|
||||
request: Request,
|
||||
response: Response,
|
||||
body: InitializeAdminRequest,
|
||||
auth_service: CurrentAuthService,
|
||||
):
|
||||
try:
|
||||
user = await auth_service.initialize_admin(body.email, body.password)
|
||||
except AuthServiceError as exc:
|
||||
raise HTTPException(
|
||||
status_code=exc.status_code,
|
||||
detail=AuthErrorResponse(code=exc.code, message=exc.message).model_dump(),
|
||||
) from exc
|
||||
|
||||
token = create_access_token(str(user.id), token_version=user.token_version)
|
||||
_set_session_cookie(response, token, request)
|
||||
return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role)
|
||||
|
||||
|
||||
@router.get("/oauth/{provider}")
|
||||
async def oauth_login(provider: str):
|
||||
if provider not in ["github", "google"]:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Unsupported OAuth provider: {provider}")
|
||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="OAuth login not yet implemented")
|
||||
|
||||
|
||||
@router.get("/callback/{provider}")
|
||||
async def oauth_callback(provider: str, code: str, state: str):
|
||||
raise HTTPException(status_code=status.HTTP_501_NOT_IMPLEMENTED, detail="OAuth callback not yet implemented")
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChangePasswordRequest",
|
||||
"InitializeAdminRequest",
|
||||
"LoginResponse",
|
||||
"MessageResponse",
|
||||
"RegisterRequest",
|
||||
"_check_rate_limit",
|
||||
"_get_client_ip",
|
||||
"_login_attempts",
|
||||
"_record_login_failure",
|
||||
"_record_login_success",
|
||||
"router",
|
||||
]
|
||||
@@ -0,0 +1,176 @@
|
||||
"""HTTP schemas and request helpers for the auth plugin API."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import time
|
||||
from ipaddress import ip_address, ip_network
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from pydantic import BaseModel, EmailStr, Field, field_validator
|
||||
|
||||
_COMMON_PASSWORDS: frozenset[str] = frozenset(
|
||||
{
|
||||
"password",
|
||||
"password1",
|
||||
"password12",
|
||||
"password123",
|
||||
"password1234",
|
||||
"12345678",
|
||||
"123456789",
|
||||
"1234567890",
|
||||
"qwerty12",
|
||||
"qwertyui",
|
||||
"qwerty123",
|
||||
"abc12345",
|
||||
"abcd1234",
|
||||
"iloveyou",
|
||||
"letmein1",
|
||||
"welcome1",
|
||||
"welcome123",
|
||||
"admin123",
|
||||
"administrator",
|
||||
"passw0rd",
|
||||
"p@ssw0rd",
|
||||
"monkey12",
|
||||
"trustno1",
|
||||
"sunshine",
|
||||
"princess",
|
||||
"football",
|
||||
"baseball",
|
||||
"superman",
|
||||
"batman123",
|
||||
"starwars",
|
||||
"dragon123",
|
||||
"master123",
|
||||
"shadow12",
|
||||
"michael1",
|
||||
"jennifer",
|
||||
"computer",
|
||||
}
|
||||
)
|
||||
_MAX_LOGIN_ATTEMPTS = 5
|
||||
_LOCKOUT_SECONDS = 300
|
||||
_MAX_TRACKED_IPS = 10000
|
||||
_login_attempts: dict[str, tuple[int, float]] = {}
|
||||
|
||||
|
||||
class LoginResponse(BaseModel):
|
||||
expires_in: int
|
||||
needs_setup: bool = False
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=8)
|
||||
|
||||
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
||||
|
||||
|
||||
class ChangePasswordRequest(BaseModel):
|
||||
current_password: str
|
||||
new_password: str = Field(..., min_length=8)
|
||||
new_email: EmailStr | None = None
|
||||
|
||||
_strong_password = field_validator("new_password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
||||
|
||||
|
||||
class MessageResponse(BaseModel):
|
||||
message: str
|
||||
|
||||
|
||||
class InitializeAdminRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str = Field(..., min_length=8)
|
||||
|
||||
_strong_password = field_validator("password")(classmethod(lambda cls, v: _validate_strong_password(v)))
|
||||
|
||||
|
||||
def _password_is_common(password: str) -> bool:
|
||||
return password.lower() in _COMMON_PASSWORDS
|
||||
|
||||
|
||||
def _validate_strong_password(value: str) -> str:
|
||||
if _password_is_common(value):
|
||||
raise ValueError("Password is too common; choose a stronger password.")
|
||||
return value
|
||||
|
||||
|
||||
def _trusted_proxies() -> list:
|
||||
raw = os.getenv("AUTH_TRUSTED_PROXIES", "").strip()
|
||||
if not raw:
|
||||
return []
|
||||
nets = []
|
||||
for entry in raw.split(","):
|
||||
entry = entry.strip()
|
||||
if not entry:
|
||||
continue
|
||||
try:
|
||||
nets.append(ip_network(entry, strict=False))
|
||||
except ValueError:
|
||||
pass
|
||||
return nets
|
||||
|
||||
|
||||
def _get_client_ip(request: Request) -> str:
|
||||
peer_host = request.client.host if request.client else None
|
||||
trusted = _trusted_proxies()
|
||||
if trusted and peer_host:
|
||||
try:
|
||||
peer_ip = ip_address(peer_host)
|
||||
if any(peer_ip in net for net in trusted):
|
||||
real_ip = request.headers.get("x-real-ip", "").strip()
|
||||
if real_ip:
|
||||
return real_ip
|
||||
except ValueError:
|
||||
pass
|
||||
return peer_host or "unknown"
|
||||
|
||||
|
||||
def _check_rate_limit(ip: str) -> None:
|
||||
record = _login_attempts.get(ip)
|
||||
if record is None:
|
||||
return
|
||||
fail_count, lock_until = record
|
||||
if fail_count >= _MAX_LOGIN_ATTEMPTS:
|
||||
if time.time() < lock_until:
|
||||
raise HTTPException(status_code=429, detail="Too many login attempts. Try again later.")
|
||||
del _login_attempts[ip]
|
||||
|
||||
|
||||
def _record_login_failure(ip: str) -> None:
|
||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||
now = time.time()
|
||||
expired = [k for k, (c, t) in _login_attempts.items() if c >= _MAX_LOGIN_ATTEMPTS and now >= t]
|
||||
for key in expired:
|
||||
del _login_attempts[key]
|
||||
if len(_login_attempts) >= _MAX_TRACKED_IPS:
|
||||
by_time = sorted(_login_attempts.items(), key=lambda kv: kv[1][1])
|
||||
for key, _ in by_time[: len(by_time) // 2]:
|
||||
del _login_attempts[key]
|
||||
|
||||
record = _login_attempts.get(ip)
|
||||
if record is None:
|
||||
_login_attempts[ip] = (1, 0.0)
|
||||
else:
|
||||
new_count = record[0] + 1
|
||||
lock_until = time.time() + _LOCKOUT_SECONDS if new_count >= _MAX_LOGIN_ATTEMPTS else 0.0
|
||||
_login_attempts[ip] = (new_count, lock_until)
|
||||
|
||||
|
||||
def _record_login_success(ip: str) -> None:
|
||||
_login_attempts.pop(ip, None)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"ChangePasswordRequest",
|
||||
"InitializeAdminRequest",
|
||||
"LoginResponse",
|
||||
"MessageResponse",
|
||||
"RegisterRequest",
|
||||
"_check_rate_limit",
|
||||
"_get_client_ip",
|
||||
"_login_attempts",
|
||||
"_record_login_failure",
|
||||
"_record_login_success",
|
||||
]
|
||||
@@ -0,0 +1,31 @@
|
||||
"""Authorization layer for the auth plugin."""
|
||||
|
||||
from app.plugins.auth.authorization.authentication import get_auth_context
|
||||
from app.plugins.auth.authorization.hooks import (
|
||||
AuthzHooks,
|
||||
build_authz_hooks,
|
||||
build_permission_provider,
|
||||
build_policy_chain_builder,
|
||||
get_authz_hooks,
|
||||
get_default_authz_hooks,
|
||||
)
|
||||
from app.plugins.auth.authorization.types import (
|
||||
AuthContext,
|
||||
Permissions,
|
||||
ALL_PERMISSIONS,
|
||||
)
|
||||
|
||||
_ALL_PERMISSIONS = ALL_PERMISSIONS
|
||||
|
||||
__all__ = [
|
||||
"AuthContext",
|
||||
"AuthzHooks",
|
||||
"Permissions",
|
||||
"_ALL_PERMISSIONS",
|
||||
"build_authz_hooks",
|
||||
"build_permission_provider",
|
||||
"build_policy_chain_builder",
|
||||
"get_auth_context",
|
||||
"get_authz_hooks",
|
||||
"get_default_authz_hooks",
|
||||
]
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Authentication helpers used by auth-plugin authorization decorators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.plugins.auth.authorization.providers import PermissionProvider, default_permission_provider
|
||||
from app.plugins.auth.authorization.types import AuthContext
|
||||
|
||||
|
||||
def get_auth_context(request: Request) -> AuthContext | None:
|
||||
"""Get AuthContext, preferring Starlette-style request.auth."""
|
||||
|
||||
auth = request.scope.get("auth")
|
||||
if isinstance(auth, AuthContext):
|
||||
return auth
|
||||
return getattr(request.state, "auth", None)
|
||||
|
||||
|
||||
def set_auth_context(request: Request, auth_context: AuthContext) -> AuthContext:
|
||||
"""Persist AuthContext on the standard request surfaces."""
|
||||
|
||||
request.scope["auth"] = auth_context
|
||||
request.state.auth = auth_context
|
||||
return auth_context
|
||||
|
||||
|
||||
async def authenticate_request(
|
||||
request: Request,
|
||||
*,
|
||||
permission_provider: PermissionProvider = default_permission_provider,
|
||||
) -> AuthContext:
|
||||
"""Authenticate request and build AuthContext."""
|
||||
|
||||
from app.plugins.auth.security.dependencies import get_optional_user_from_request
|
||||
|
||||
user = await get_optional_user_from_request(request)
|
||||
if user is None:
|
||||
return AuthContext(user=None, permissions=[])
|
||||
return AuthContext(user=user, permissions=permission_provider(user))
|
||||
|
||||
|
||||
__all__ = ["authenticate_request", "get_auth_context", "set_auth_context"]
|
||||
@@ -0,0 +1,84 @@
|
||||
"""Authorization requirement and policy evaluation helpers."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from app.plugins.auth.authorization.policies import require_thread_owner
|
||||
from app.plugins.auth.authorization.types import AuthContext
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PermissionRequirement:
|
||||
"""Authorization requirement for a single route action."""
|
||||
|
||||
resource: str
|
||||
action: str
|
||||
owner_check: bool = False
|
||||
require_existing: bool = False
|
||||
|
||||
@property
|
||||
def permission(self) -> str:
|
||||
return f"{self.resource}:{self.action}"
|
||||
|
||||
|
||||
PolicyEvaluator = Callable[[Request, AuthContext, PermissionRequirement, Mapping[str, Any]], Awaitable[None]]
|
||||
|
||||
|
||||
def ensure_authenticated(auth: AuthContext) -> None:
|
||||
if not auth.is_authenticated:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
|
||||
|
||||
def ensure_capability(auth: AuthContext, requirement: PermissionRequirement) -> None:
|
||||
if not auth.has_permission(requirement.resource, requirement.action):
|
||||
raise HTTPException(status_code=403, detail=f"Permission denied: {requirement.permission}")
|
||||
|
||||
|
||||
async def evaluate_owner_policy(
|
||||
request: Request,
|
||||
auth: AuthContext,
|
||||
requirement: PermissionRequirement,
|
||||
route_params: Mapping[str, Any],
|
||||
) -> None:
|
||||
if not requirement.owner_check:
|
||||
return
|
||||
|
||||
thread_id = route_params.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||
|
||||
await require_thread_owner(
|
||||
request,
|
||||
auth,
|
||||
thread_id=thread_id,
|
||||
require_existing=requirement.require_existing,
|
||||
)
|
||||
|
||||
|
||||
async def evaluate_requirement(
|
||||
request: Request,
|
||||
auth: AuthContext,
|
||||
requirement: PermissionRequirement,
|
||||
route_params: Mapping[str, Any],
|
||||
*,
|
||||
policy_evaluators: tuple[PolicyEvaluator, ...],
|
||||
) -> None:
|
||||
ensure_authenticated(auth)
|
||||
ensure_capability(auth, requirement)
|
||||
for evaluator in policy_evaluators:
|
||||
await evaluator(request, auth, requirement, route_params)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"PermissionRequirement",
|
||||
"PolicyEvaluator",
|
||||
"ensure_authenticated",
|
||||
"ensure_capability",
|
||||
"evaluate_owner_policy",
|
||||
"evaluate_requirement",
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
"""Auth-plugin authz extension hooks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from app.plugins.auth.authorization.providers import PermissionProvider, default_permission_provider
|
||||
from app.plugins.auth.authorization.registry import PolicyChainBuilder, build_default_policy_evaluators
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AuthzHooks:
|
||||
"""Extension hooks for permission and policy resolution."""
|
||||
|
||||
permission_provider: PermissionProvider = default_permission_provider
|
||||
policy_chain_builder: PolicyChainBuilder = build_default_policy_evaluators
|
||||
|
||||
|
||||
DEFAULT_AUTHZ_HOOKS = AuthzHooks()
|
||||
|
||||
|
||||
def get_default_authz_hooks() -> AuthzHooks:
|
||||
return DEFAULT_AUTHZ_HOOKS
|
||||
|
||||
|
||||
def get_authz_hooks(request: Request | Any | None = None) -> AuthzHooks:
|
||||
if request is not None:
|
||||
app = getattr(request, "app", None)
|
||||
state = getattr(app, "state", None)
|
||||
hooks = getattr(state, "authz_hooks", None)
|
||||
if isinstance(hooks, AuthzHooks):
|
||||
return hooks
|
||||
return DEFAULT_AUTHZ_HOOKS
|
||||
|
||||
|
||||
def build_permission_provider() -> PermissionProvider:
|
||||
return default_permission_provider
|
||||
|
||||
|
||||
def build_policy_chain_builder() -> PolicyChainBuilder:
|
||||
return build_default_policy_evaluators
|
||||
|
||||
|
||||
def build_authz_hooks() -> AuthzHooks:
|
||||
return AuthzHooks(
|
||||
permission_provider=build_permission_provider(),
|
||||
policy_chain_builder=build_policy_chain_builder(),
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AuthzHooks",
|
||||
"DEFAULT_AUTHZ_HOOKS",
|
||||
"build_authz_hooks",
|
||||
"build_permission_provider",
|
||||
"build_policy_chain_builder",
|
||||
"get_authz_hooks",
|
||||
"get_default_authz_hooks",
|
||||
]
|
||||
@@ -0,0 +1,101 @@
|
||||
"""Authorization policies for resource ownership and access checks."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from app.plugins.auth.authorization.types import AuthContext
|
||||
|
||||
|
||||
def _get_thread_owner_id(thread_meta: Any) -> str | None:
|
||||
owner_id = getattr(thread_meta, "user_id", None)
|
||||
if owner_id is not None:
|
||||
return str(owner_id)
|
||||
|
||||
metadata = getattr(thread_meta, "metadata", None) or {}
|
||||
metadata_owner_id = metadata.get("user_id")
|
||||
if metadata_owner_id is not None:
|
||||
return str(metadata_owner_id)
|
||||
return None
|
||||
|
||||
|
||||
async def _thread_exists_via_legacy_sources(request: Request, auth: AuthContext, *, thread_id: str) -> bool:
|
||||
from app.gateway.dependencies.repositories import get_run_repository
|
||||
|
||||
principal_id = auth.principal_id
|
||||
run_store = get_run_repository(request)
|
||||
runs = await run_store.list_by_thread(
|
||||
thread_id,
|
||||
limit=1,
|
||||
user_id=principal_id,
|
||||
)
|
||||
if runs:
|
||||
return True
|
||||
|
||||
checkpointer = getattr(request.app.state, "checkpointer", None)
|
||||
if checkpointer is None:
|
||||
return False
|
||||
|
||||
checkpoint_tuple = await checkpointer.aget_tuple(
|
||||
{"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||
)
|
||||
return checkpoint_tuple is not None
|
||||
|
||||
|
||||
async def require_thread_owner(
|
||||
request: Request,
|
||||
auth: AuthContext,
|
||||
*,
|
||||
thread_id: str,
|
||||
require_existing: bool,
|
||||
) -> None:
|
||||
"""Ensure the current user owns the thread referenced by ``thread_id``."""
|
||||
|
||||
from app.gateway.dependencies.repositories import get_thread_meta_repository
|
||||
|
||||
thread_repo = get_thread_meta_repository(request)
|
||||
thread_meta = await thread_repo.get_thread_meta(thread_id)
|
||||
if thread_meta is None:
|
||||
allowed = not require_existing
|
||||
if not allowed:
|
||||
allowed = await _thread_exists_via_legacy_sources(request, auth, thread_id=thread_id)
|
||||
else:
|
||||
owner_id = _get_thread_owner_id(thread_meta)
|
||||
allowed = owner_id in (None, str(auth.user.id))
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Thread {thread_id} not found",
|
||||
)
|
||||
|
||||
|
||||
async def require_run_owner(
|
||||
request: Request,
|
||||
auth: AuthContext,
|
||||
*,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
require_existing: bool,
|
||||
) -> None:
|
||||
"""Ensure the current user owns the run referenced by ``run_id``."""
|
||||
|
||||
from app.gateway.dependencies import get_run_repository
|
||||
|
||||
run_store = get_run_repository(request)
|
||||
run = await run_store.get(run_id)
|
||||
if run is None:
|
||||
allowed = not require_existing
|
||||
else:
|
||||
allowed = run.get("thread_id") == thread_id
|
||||
|
||||
if not allowed:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"Run {run_id} not found",
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["require_run_owner", "require_thread_owner"]
|
||||
@@ -0,0 +1,18 @@
|
||||
"""Default permission provider hooks for auth-plugin authorization."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from app.plugins.auth.authorization.types import ALL_PERMISSIONS
|
||||
|
||||
PermissionProvider = Callable[[object], list[str]]
|
||||
|
||||
|
||||
def default_permission_provider(user: object) -> list[str]:
|
||||
"""Return the current static permission set for an authenticated user."""
|
||||
|
||||
return list(ALL_PERMISSIONS)
|
||||
|
||||
|
||||
__all__ = ["PermissionProvider", "default_permission_provider"]
|
||||
@@ -0,0 +1,23 @@
|
||||
"""Registry/build helpers for default authorization evaluators."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.plugins.auth.authorization.authorization import PolicyEvaluator
|
||||
|
||||
|
||||
PolicyChainBuilder = Callable[[], tuple["PolicyEvaluator", ...]]
|
||||
|
||||
|
||||
def build_default_policy_evaluators() -> tuple["PolicyEvaluator", ...]:
|
||||
"""Return the default policy chain for auth-plugin authorization."""
|
||||
|
||||
from app.plugins.auth.authorization.authorization import evaluate_owner_policy
|
||||
|
||||
return (evaluate_owner_policy,)
|
||||
|
||||
|
||||
__all__ = ["PolicyChainBuilder", "build_default_policy_evaluators"]
|
||||
@@ -0,0 +1,67 @@
|
||||
"""Authorization context and capability constants for the auth plugin."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.plugins.auth.domain.models import User
|
||||
|
||||
|
||||
class Permissions:
|
||||
"""Permission constants for resource:action format."""
|
||||
|
||||
THREADS_READ = "threads:read"
|
||||
THREADS_WRITE = "threads:write"
|
||||
THREADS_DELETE = "threads:delete"
|
||||
|
||||
RUNS_CREATE = "runs:create"
|
||||
RUNS_READ = "runs:read"
|
||||
RUNS_CANCEL = "runs:cancel"
|
||||
|
||||
|
||||
class AuthContext:
|
||||
"""Authentication context for the current request."""
|
||||
|
||||
__slots__ = ("user", "permissions")
|
||||
|
||||
def __init__(self, user: User | None = None, permissions: list[str] | None = None):
|
||||
self.user = user
|
||||
self.permissions = permissions or []
|
||||
|
||||
@property
|
||||
def is_authenticated(self) -> bool:
|
||||
return self.user is not None
|
||||
|
||||
@property
|
||||
def principal_id(self) -> str | None:
|
||||
if self.user is None:
|
||||
return None
|
||||
return str(self.user.id)
|
||||
|
||||
@property
|
||||
def capabilities(self) -> tuple[str, ...]:
|
||||
return tuple(self.permissions)
|
||||
|
||||
def has_permission(self, resource: str, action: str) -> bool:
|
||||
return f"{resource}:{action}" in self.permissions
|
||||
|
||||
def require_user(self) -> User:
|
||||
if not self.user:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
return self.user
|
||||
|
||||
|
||||
ALL_PERMISSIONS: list[str] = [
|
||||
Permissions.THREADS_READ,
|
||||
Permissions.THREADS_WRITE,
|
||||
Permissions.THREADS_DELETE,
|
||||
Permissions.RUNS_CREATE,
|
||||
Permissions.RUNS_READ,
|
||||
Permissions.RUNS_CANCEL,
|
||||
]
|
||||
|
||||
|
||||
__all__ = ["ALL_PERMISSIONS", "AuthContext", "Permissions"]
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Domain layer for the auth plugin."""
|
||||
|
||||
from app.plugins.auth.domain.config import AuthConfig, load_auth_config_from_env
|
||||
from app.plugins.auth.domain.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code
|
||||
from app.plugins.auth.domain.jwt import TokenPayload, create_access_token, decode_token
|
||||
from app.plugins.auth.domain.models import User, UserResponse
|
||||
from app.plugins.auth.domain.password import hash_password, hash_password_async, verify_password, verify_password_async
|
||||
from app.plugins.auth.domain.service import AuthService, AuthServiceError
|
||||
|
||||
__all__ = [
|
||||
"AuthConfig",
|
||||
"AuthErrorCode",
|
||||
"AuthErrorResponse",
|
||||
"AuthService",
|
||||
"AuthServiceError",
|
||||
"TokenError",
|
||||
"TokenPayload",
|
||||
"User",
|
||||
"UserResponse",
|
||||
"create_access_token",
|
||||
"decode_token",
|
||||
"hash_password",
|
||||
"hash_password_async",
|
||||
"load_auth_config_from_env",
|
||||
"token_error_to_code",
|
||||
"verify_password",
|
||||
"verify_password_async",
|
||||
]
|
||||
@@ -0,0 +1,42 @@
|
||||
"""Auth configuration schema and environment loader."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""JWT and auth-related configuration."""
|
||||
|
||||
jwt_secret: str = Field(..., description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.")
|
||||
token_expiry_days: int = Field(default=7, ge=1, le=30)
|
||||
oauth_github_client_id: str | None = Field(default=None)
|
||||
oauth_github_client_secret: str | None = Field(default=None)
|
||||
|
||||
|
||||
def load_auth_config_from_env() -> AuthConfig:
|
||||
"""Build an auth config from environment variables."""
|
||||
|
||||
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
||||
if not jwt_secret:
|
||||
jwt_secret = secrets.token_urlsafe(32)
|
||||
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
||||
logger.warning(
|
||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
|
||||
"Sessions will be invalidated on restart. "
|
||||
"For production, add AUTH_JWT_SECRET to your .env file: "
|
||||
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||
)
|
||||
return AuthConfig(jwt_secret=jwt_secret)
|
||||
|
||||
|
||||
__all__ = ["AuthConfig", "load_auth_config_from_env"]
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Typed error definitions for auth plugin."""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthErrorCode(StrEnum):
|
||||
INVALID_CREDENTIALS = "invalid_credentials"
|
||||
TOKEN_EXPIRED = "token_expired"
|
||||
TOKEN_INVALID = "token_invalid"
|
||||
USER_NOT_FOUND = "user_not_found"
|
||||
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
||||
PROVIDER_NOT_FOUND = "provider_not_found"
|
||||
NOT_AUTHENTICATED = "not_authenticated"
|
||||
SYSTEM_ALREADY_INITIALIZED = "system_already_initialized"
|
||||
|
||||
|
||||
class TokenError(StrEnum):
|
||||
EXPIRED = "expired"
|
||||
INVALID_SIGNATURE = "invalid_signature"
|
||||
MALFORMED = "malformed"
|
||||
|
||||
|
||||
class AuthErrorResponse(BaseModel):
|
||||
code: AuthErrorCode
|
||||
message: str
|
||||
|
||||
|
||||
def token_error_to_code(err: TokenError) -> AuthErrorCode:
|
||||
if err == TokenError.EXPIRED:
|
||||
return AuthErrorCode.TOKEN_EXPIRED
|
||||
return AuthErrorCode.TOKEN_INVALID
|
||||
@@ -0,0 +1,37 @@
|
||||
"""JWT token creation and verification."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.plugins.auth.domain.errors import TokenError
|
||||
from app.plugins.auth.runtime.config_state import get_auth_config
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
sub: str
|
||||
exp: datetime
|
||||
iat: datetime | None = None
|
||||
ver: int = 0
|
||||
|
||||
|
||||
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
|
||||
config = get_auth_config()
|
||||
expiry = expires_delta or timedelta(days=config.token_expiry_days)
|
||||
now = datetime.now(UTC)
|
||||
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
|
||||
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_token(token: str) -> TokenPayload | TokenError:
|
||||
config = get_auth_config()
|
||||
try:
|
||||
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
|
||||
return TokenPayload(**payload)
|
||||
except jwt.ExpiredSignatureError:
|
||||
return TokenError.EXPIRED
|
||||
except jwt.InvalidSignatureError:
|
||||
return TokenError.INVALID_SIGNATURE
|
||||
except jwt.PyJWTError:
|
||||
return TokenError.MALFORMED
|
||||
@@ -0,0 +1,32 @@
|
||||
"""User Pydantic models for the auth plugin."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, description="Primary key")
|
||||
email: EmailStr = Field(..., description="Unique email address")
|
||||
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
|
||||
system_role: Literal["admin", "user"] = Field(default="user")
|
||||
created_at: datetime = Field(default_factory=_utc_now)
|
||||
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
|
||||
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
|
||||
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
|
||||
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
system_role: Literal["admin", "user"]
|
||||
needs_setup: bool = False
|
||||
@@ -0,0 +1,21 @@
|
||||
"""Password hashing utilities using bcrypt directly."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import bcrypt
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
|
||||
|
||||
|
||||
async def hash_password_async(password: str) -> str:
|
||||
return await asyncio.to_thread(hash_password, password)
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
return await asyncio.to_thread(verify_password, plain_password, hashed_password)
|
||||
@@ -0,0 +1,175 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.plugins.auth.domain.errors import AuthErrorCode
|
||||
from app.plugins.auth.domain.models import User
|
||||
from app.plugins.auth.domain.password import hash_password_async, verify_password_async
|
||||
from app.plugins.auth.storage import DbUserRepository, UserCreate
|
||||
from app.plugins.auth.storage.contracts import User as StoreUser
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AuthServiceError(Exception):
|
||||
code: AuthErrorCode
|
||||
message: str
|
||||
status_code: int
|
||||
|
||||
|
||||
def _to_auth_user(user: StoreUser) -> User:
|
||||
return User(
|
||||
id=UUID(user.id),
|
||||
email=user.email,
|
||||
password_hash=user.password_hash,
|
||||
system_role=user.system_role, # type: ignore[arg-type]
|
||||
created_at=user.created_time,
|
||||
oauth_provider=user.oauth_provider,
|
||||
oauth_id=user.oauth_id,
|
||||
needs_setup=user.needs_setup,
|
||||
token_version=user.token_version,
|
||||
)
|
||||
|
||||
|
||||
def _to_store_user(user: User) -> StoreUser:
|
||||
return StoreUser(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
password_hash=user.password_hash,
|
||||
system_role=user.system_role,
|
||||
oauth_provider=user.oauth_provider,
|
||||
oauth_id=user.oauth_id,
|
||||
needs_setup=user.needs_setup,
|
||||
token_version=user.token_version,
|
||||
created_time=user.created_at,
|
||||
updated_time=None,
|
||||
)
|
||||
|
||||
|
||||
class AuthService:
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def login_local(self, email: str, password: str) -> User:
|
||||
async with self._session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
user = await repo.get_user_by_email(email)
|
||||
if user is None or user.password_hash is None:
|
||||
raise AuthServiceError(
|
||||
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||
message="Incorrect email or password",
|
||||
status_code=HTTPStatus.UNAUTHORIZED,
|
||||
)
|
||||
if not await verify_password_async(password, user.password_hash):
|
||||
raise AuthServiceError(
|
||||
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||
message="Incorrect email or password",
|
||||
status_code=HTTPStatus.UNAUTHORIZED,
|
||||
)
|
||||
return _to_auth_user(user)
|
||||
|
||||
async def register(self, email: str, password: str) -> User:
|
||||
async with self._session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
try:
|
||||
user = await repo.create_user(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password_hash=await hash_password_async(password),
|
||||
system_role="user",
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
except ValueError as exc:
|
||||
await session.rollback()
|
||||
raise AuthServiceError(
|
||||
code=AuthErrorCode.EMAIL_ALREADY_EXISTS,
|
||||
message="Email already registered",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
) from exc
|
||||
return _to_auth_user(user)
|
||||
|
||||
async def change_password(
|
||||
self,
|
||||
user: User | StoreUser,
|
||||
*,
|
||||
current_password: str,
|
||||
new_password: str,
|
||||
new_email: str | None = None,
|
||||
) -> User:
|
||||
if user.password_hash is None:
|
||||
raise AuthServiceError(
|
||||
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||
message="OAuth users cannot change password",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
if not await verify_password_async(current_password, user.password_hash):
|
||||
raise AuthServiceError(
|
||||
code=AuthErrorCode.INVALID_CREDENTIALS,
|
||||
message="Current password is incorrect",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
|
||||
async with self._session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
updated_email = user.email
|
||||
if new_email is not None:
|
||||
existing = await repo.get_user_by_email(new_email)
|
||||
if existing and existing.id != str(user.id):
|
||||
raise AuthServiceError(
|
||||
code=AuthErrorCode.EMAIL_ALREADY_EXISTS,
|
||||
message="Email already in use",
|
||||
status_code=HTTPStatus.BAD_REQUEST,
|
||||
)
|
||||
updated_email = new_email
|
||||
|
||||
updated_user = user.model_copy(
|
||||
update={
|
||||
"email": updated_email,
|
||||
"password_hash": await hash_password_async(new_password),
|
||||
"token_version": user.token_version + 1,
|
||||
"needs_setup": False if user.needs_setup and new_email is not None else user.needs_setup,
|
||||
}
|
||||
)
|
||||
|
||||
updated = await repo.update_user(_to_store_user(_to_auth_user(updated_user) if isinstance(updated_user, StoreUser) else updated_user))
|
||||
await session.commit()
|
||||
return _to_auth_user(updated)
|
||||
|
||||
async def get_setup_status(self) -> bool:
|
||||
async with self._session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
admin_count = await repo.count_admin_users()
|
||||
return admin_count == 0
|
||||
|
||||
async def initialize_admin(self, email: str, password: str) -> User:
|
||||
async with self._session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
admin_count = await repo.count_admin_users()
|
||||
if admin_count > 0:
|
||||
raise AuthServiceError(
|
||||
code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED,
|
||||
message="System already initialized",
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
)
|
||||
try:
|
||||
user = await repo.create_user(
|
||||
UserCreate(
|
||||
email=email,
|
||||
password_hash=await hash_password_async(password),
|
||||
system_role="admin",
|
||||
needs_setup=False,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
except ValueError as exc:
|
||||
await session.rollback()
|
||||
raise AuthServiceError(
|
||||
code=AuthErrorCode.SYSTEM_ALREADY_INITIALIZED,
|
||||
message="System already initialized",
|
||||
status_code=HTTPStatus.CONFLICT,
|
||||
) from exc
|
||||
return _to_auth_user(user)
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Config-driven route authorization injection for the auth plugin."""
|
||||
|
||||
from app.plugins.auth.injection.registry_loader import (
|
||||
RoutePolicyRegistry,
|
||||
RoutePolicySpec,
|
||||
load_route_policy_registry,
|
||||
)
|
||||
from app.plugins.auth.injection.route_injector import install_route_guards
|
||||
from app.plugins.auth.injection.validation import validate_route_policy_registry
|
||||
|
||||
__all__ = [
|
||||
"RoutePolicyRegistry",
|
||||
"RoutePolicySpec",
|
||||
"install_route_guards",
|
||||
"load_route_policy_registry",
|
||||
"validate_route_policy_registry",
|
||||
]
|
||||
@@ -0,0 +1,112 @@
|
||||
"""Load auth route policies from the plugin's YAML registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from starlette.routing import compile_path
|
||||
import yaml
|
||||
|
||||
_POLICY_FILE = Path(__file__).resolve().parents[1] / "route_policies.yaml"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RoutePolicySpec:
|
||||
public: bool = False
|
||||
capability: str | None = None
|
||||
policies: tuple[str, ...] = ()
|
||||
require_existing: bool = True
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RoutePolicyEntry:
|
||||
method: str
|
||||
path: str
|
||||
spec: RoutePolicySpec
|
||||
path_regex: object = field(repr=False)
|
||||
|
||||
def matches_request(self, method: str, path: str) -> bool:
|
||||
if self.method != method.upper():
|
||||
return False
|
||||
return self.path_regex.match(path) is not None
|
||||
|
||||
|
||||
class RoutePolicyRegistry:
|
||||
def __init__(self, entries: list[RoutePolicyEntry]) -> None:
|
||||
self._entries = entries
|
||||
self._specs = {(entry.method, entry.path): entry.spec for entry in entries}
|
||||
|
||||
def get(self, method: str, path_template: str) -> RoutePolicySpec | None:
|
||||
return self._specs.get((method.upper(), path_template))
|
||||
|
||||
def has(self, method: str, path_template: str) -> bool:
|
||||
return (method.upper(), path_template) in self._specs
|
||||
|
||||
def match_request(self, method: str, path: str) -> RoutePolicySpec | None:
|
||||
normalized_method = method.upper()
|
||||
for entry in self._entries:
|
||||
if entry.matches_request(normalized_method, path):
|
||||
return entry.spec
|
||||
return None
|
||||
|
||||
def is_public_request(self, method: str, path: str) -> bool:
|
||||
spec = self.match_request(method, path)
|
||||
return bool(spec and spec.public)
|
||||
|
||||
@property
|
||||
def keys(self) -> set[tuple[str, str]]:
|
||||
return set(self._specs)
|
||||
|
||||
|
||||
def _normalize_methods(item: dict) -> tuple[str, ...]:
|
||||
methods = item.get("methods")
|
||||
if methods is None:
|
||||
methods = [item["method"]]
|
||||
if isinstance(methods, str):
|
||||
methods = [methods]
|
||||
return tuple(str(method).upper() for method in methods)
|
||||
|
||||
|
||||
def _build_spec(item: dict) -> RoutePolicySpec:
|
||||
return RoutePolicySpec(
|
||||
public=bool(item.get("public", False)),
|
||||
capability=item.get("capability"),
|
||||
policies=tuple(item.get("policies", [])),
|
||||
require_existing=bool(item.get("require_existing", True)),
|
||||
)
|
||||
|
||||
|
||||
def load_route_policy_registry() -> RoutePolicyRegistry:
|
||||
payload = yaml.safe_load(_POLICY_FILE.read_text(encoding="utf-8")) or {}
|
||||
raw_routes: list[dict] = []
|
||||
for section, entries in payload.items():
|
||||
if section == "routes":
|
||||
if isinstance(entries, list):
|
||||
raw_routes.extend(entries)
|
||||
continue
|
||||
if not isinstance(entries, list):
|
||||
continue
|
||||
for item in entries:
|
||||
normalized = dict(item)
|
||||
if section == "public":
|
||||
normalized["public"] = True
|
||||
raw_routes.append(normalized)
|
||||
entries: list[RoutePolicyEntry] = []
|
||||
for item in raw_routes:
|
||||
path = str(item["path"])
|
||||
spec = _build_spec(item)
|
||||
path_regex, _, _ = compile_path(path)
|
||||
for method in _normalize_methods(item):
|
||||
entries.append(
|
||||
RoutePolicyEntry(
|
||||
method=method,
|
||||
path=path,
|
||||
spec=spec,
|
||||
path_regex=path_regex,
|
||||
)
|
||||
)
|
||||
return RoutePolicyRegistry(entries)
|
||||
|
||||
|
||||
__all__ = ["RoutePolicyRegistry", "RoutePolicySpec", "load_route_policy_registry"]
|
||||
@@ -0,0 +1,102 @@
|
||||
"""Runtime route guard backed by the auth plugin's route policy registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
|
||||
from app.plugins.auth.authorization.authentication import (
|
||||
authenticate_request,
|
||||
get_auth_context,
|
||||
set_auth_context,
|
||||
)
|
||||
from app.plugins.auth.authorization.authorization import ensure_authenticated
|
||||
from app.plugins.auth.authorization.hooks import get_authz_hooks
|
||||
from app.plugins.auth.authorization.policies import require_run_owner, require_thread_owner
|
||||
from app.plugins.auth.injection.registry_loader import RoutePolicyRegistry, RoutePolicySpec
|
||||
|
||||
PolicyGuard = Callable[[Request, RoutePolicySpec], Awaitable[None]]
|
||||
|
||||
|
||||
async def _check_capability(request: Request, spec: RoutePolicySpec) -> None:
|
||||
if not spec.capability:
|
||||
return
|
||||
|
||||
auth = get_auth_context(request)
|
||||
if auth is None:
|
||||
raise HTTPException(status_code=500, detail="Missing auth context")
|
||||
|
||||
if ":" not in spec.capability:
|
||||
raise RuntimeError(f"Invalid capability format: {spec.capability}")
|
||||
resource, action = spec.capability.split(":", 1)
|
||||
if not auth.has_permission(resource, action):
|
||||
raise HTTPException(status_code=403, detail=f"Permission denied: {spec.capability}")
|
||||
|
||||
|
||||
async def _guard_thread_owner(request: Request, spec: RoutePolicySpec) -> None:
|
||||
auth = get_auth_context(request)
|
||||
if auth is None:
|
||||
raise HTTPException(status_code=500, detail="Missing auth context")
|
||||
thread_id = request.path_params.get("thread_id")
|
||||
if not isinstance(thread_id, str):
|
||||
raise RuntimeError("owner:thread policy requires thread_id path parameter")
|
||||
await require_thread_owner(request, auth, thread_id=thread_id, require_existing=spec.require_existing)
|
||||
|
||||
|
||||
async def _guard_run_owner(request: Request, spec: RoutePolicySpec) -> None:
|
||||
auth = get_auth_context(request)
|
||||
if auth is None:
|
||||
raise HTTPException(status_code=500, detail="Missing auth context")
|
||||
thread_id = request.path_params.get("thread_id")
|
||||
run_id = request.path_params.get("run_id")
|
||||
if not isinstance(thread_id, str) or not isinstance(run_id, str):
|
||||
raise RuntimeError("owner:run policy requires thread_id and run_id path parameters")
|
||||
await require_run_owner(
|
||||
request,
|
||||
auth,
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
require_existing=spec.require_existing,
|
||||
)
|
||||
|
||||
|
||||
_POLICY_GUARDS: dict[str, PolicyGuard] = {
|
||||
"owner:thread": _guard_thread_owner,
|
||||
"owner:run": _guard_run_owner,
|
||||
}
|
||||
|
||||
|
||||
async def enforce_route_policy(request: Request) -> None:
|
||||
registry = getattr(request.app.state, "auth_route_policy_registry", None)
|
||||
if not isinstance(registry, RoutePolicyRegistry):
|
||||
raise RuntimeError("Auth route policy registry is not configured")
|
||||
|
||||
route = request.scope.get("route")
|
||||
path_template = getattr(route, "path", None)
|
||||
if not isinstance(path_template, str):
|
||||
raise RuntimeError("Unable to resolve route path for authorization")
|
||||
|
||||
spec = registry.get(request.method, path_template)
|
||||
if spec is None:
|
||||
raise RuntimeError(f"Missing auth route policy for {request.method} {path_template}")
|
||||
if spec.public:
|
||||
return
|
||||
|
||||
auth = get_auth_context(request)
|
||||
if auth is None:
|
||||
hooks = get_authz_hooks(request)
|
||||
auth = await authenticate_request(request, permission_provider=hooks.permission_provider)
|
||||
set_auth_context(request, auth)
|
||||
|
||||
ensure_authenticated(auth)
|
||||
await _check_capability(request, spec)
|
||||
|
||||
for policy_name in spec.policies:
|
||||
guard = _POLICY_GUARDS.get(policy_name)
|
||||
if guard is None:
|
||||
raise RuntimeError(f"Unknown route policy guard: {policy_name}")
|
||||
await guard(request, spec)
|
||||
|
||||
|
||||
__all__ = ["enforce_route_policy"]
|
||||
@@ -0,0 +1,39 @@
|
||||
"""Inject config-driven auth guards into FastAPI routes."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, FastAPI
|
||||
from fastapi.dependencies.utils import get_dependant, get_flat_dependant, get_parameterless_sub_dependant
|
||||
from fastapi.routing import APIRoute, _should_embed_body_fields, get_body_field, request_response
|
||||
|
||||
from app.plugins.auth.injection.route_guard import enforce_route_policy
|
||||
|
||||
|
||||
def _rebuild_route(route: APIRoute) -> None:
|
||||
route.dependant = get_dependant(path=route.path_format, call=route.endpoint, scope="function")
|
||||
for depends in route.dependencies[::-1]:
|
||||
route.dependant.dependencies.insert(
|
||||
0,
|
||||
get_parameterless_sub_dependant(depends=depends, path=route.path_format),
|
||||
)
|
||||
route._flat_dependant = get_flat_dependant(route.dependant)
|
||||
route._embed_body_fields = _should_embed_body_fields(route._flat_dependant.body_params)
|
||||
route.body_field = get_body_field(
|
||||
flat_dependant=route._flat_dependant,
|
||||
name=route.unique_id,
|
||||
embed_body_fields=route._embed_body_fields,
|
||||
)
|
||||
route.app = request_response(route.get_route_handler())
|
||||
|
||||
|
||||
def install_route_guards(app: FastAPI) -> None:
|
||||
for route in app.routes:
|
||||
if not isinstance(route, APIRoute):
|
||||
continue
|
||||
if any(getattr(dependency, "dependency", None) is enforce_route_policy for dependency in route.dependencies):
|
||||
continue
|
||||
route.dependencies.append(Depends(enforce_route_policy))
|
||||
_rebuild_route(route)
|
||||
|
||||
|
||||
__all__ = ["install_route_guards"]
|
||||
@@ -0,0 +1,38 @@
|
||||
"""Validation helpers for config-driven auth route policies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.routing import APIRoute
|
||||
|
||||
from app.plugins.auth.injection.registry_loader import RoutePolicyRegistry
|
||||
|
||||
_IGNORED_METHODS = frozenset({"HEAD", "OPTIONS"})
|
||||
|
||||
|
||||
def _iter_route_keys(app: FastAPI) -> set[tuple[str, str]]:
|
||||
keys: set[tuple[str, str]] = set()
|
||||
for route in app.routes:
|
||||
if not isinstance(route, APIRoute):
|
||||
continue
|
||||
for method in route.methods:
|
||||
if method in _IGNORED_METHODS:
|
||||
continue
|
||||
keys.add((method, route.path))
|
||||
return keys
|
||||
|
||||
|
||||
def validate_route_policy_registry(app: FastAPI, registry: RoutePolicyRegistry) -> None:
|
||||
route_keys = _iter_route_keys(app)
|
||||
missing = sorted(route_keys - registry.keys)
|
||||
extra = sorted(registry.keys - route_keys)
|
||||
problems: list[str] = []
|
||||
if missing:
|
||||
problems.append("Missing route policy entries:\n" + "\n".join(f" - {method} {path}" for method, path in missing))
|
||||
if extra:
|
||||
problems.append("Unknown route policy entries:\n" + "\n".join(f" - {method} {path}" for method, path in extra))
|
||||
if problems:
|
||||
raise RuntimeError("\n\n".join(problems))
|
||||
|
||||
|
||||
__all__ = ["validate_route_policy_registry"]
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Operational tooling for the auth plugin."""
|
||||
|
||||
from app.plugins.auth.ops.credential_file import write_initial_credentials
|
||||
from app.plugins.auth.ops.reset_admin import main
|
||||
|
||||
__all__ = ["main", "write_initial_credentials"]
|
||||
@@ -0,0 +1,28 @@
|
||||
"""Write initial admin credentials to a restricted file instead of logs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
_CREDENTIAL_FILENAME = "admin_initial_credentials.txt"
|
||||
|
||||
|
||||
def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
|
||||
target = get_paths().base_dir / _CREDENTIAL_FILENAME
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
content = (
|
||||
f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n"
|
||||
)
|
||||
|
||||
fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
||||
fh.write(content)
|
||||
|
||||
return target.resolve()
|
||||
|
||||
|
||||
__all__ = ["write_initial_credentials"]
|
||||
@@ -0,0 +1,74 @@
|
||||
"""CLI tool to reset an admin password."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.plugins.auth.domain.password import hash_password
|
||||
from app.plugins.auth.ops.credential_file import write_initial_credentials
|
||||
from app.plugins.auth.storage import DbUserRepository
|
||||
from app.plugins.auth.storage.models import User as UserModel
|
||||
|
||||
|
||||
async def _run(email: str | None) -> int:
|
||||
from store.persistence import create_persistence
|
||||
|
||||
app_persistence = await create_persistence()
|
||||
await app_persistence.setup()
|
||||
try:
|
||||
if email:
|
||||
async with app_persistence.session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
user = await repo.get_user_by_email(email)
|
||||
else:
|
||||
async with app_persistence.session_factory() as session:
|
||||
stmt = select(UserModel).where(UserModel.system_role == "admin").limit(1)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
user = None
|
||||
else:
|
||||
repo = DbUserRepository(session)
|
||||
user = await repo.get_user_by_id(row.id)
|
||||
|
||||
if user is None:
|
||||
print(f"Error: user '{email}' not found." if email else "Error: no admin user found.", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
updated_user = user.model_copy(
|
||||
update={
|
||||
"password_hash": hash_password(new_password),
|
||||
"token_version": user.token_version + 1,
|
||||
"needs_setup": True,
|
||||
}
|
||||
)
|
||||
async with app_persistence.session_factory() as session:
|
||||
repo = DbUserRepository(session)
|
||||
await repo.update_user(updated_user)
|
||||
await session.commit()
|
||||
|
||||
cred_path = write_initial_credentials(user.email, new_password, label="reset")
|
||||
print(f"Password reset for: {user.email}")
|
||||
print(f"Credentials written to: {cred_path} (mode 0600)")
|
||||
print("Next login will require setup (new email + password).")
|
||||
return 0
|
||||
finally:
|
||||
await app_persistence.aclose()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Reset admin password")
|
||||
parser.add_argument("--email", help="Admin email (default: first admin found)")
|
||||
args = parser.parse_args()
|
||||
|
||||
exit_code = asyncio.run(_run(args.email))
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,18 @@
|
||||
[plugin]
|
||||
name = "auth"
|
||||
summary = "Cookie-based authentication and authorization"
|
||||
version = "0.1.0"
|
||||
description = "Owns DeerFlow authentication, authorization adapters, and auth storage definitions while reusing shared persistence infrastructure."
|
||||
author = "DeerFlow"
|
||||
tags = ["auth", "gateway", "session"]
|
||||
|
||||
[capabilities]
|
||||
router = true
|
||||
middleware = true
|
||||
dependencies = true
|
||||
langgraph_adapter = true
|
||||
storage = true
|
||||
|
||||
[storage]
|
||||
mode = "shared_infrastructure"
|
||||
notes = "This plugin owns its storage definitions and repositories but uses the application's shared engine and session factory."
|
||||
@@ -0,0 +1,204 @@
|
||||
public:
|
||||
- method: POST
|
||||
path: /api/v1/auth/login/local
|
||||
- method: POST
|
||||
path: /api/v1/auth/register
|
||||
- method: POST
|
||||
path: /api/v1/auth/logout
|
||||
- method: GET
|
||||
path: /api/v1/auth/setup-status
|
||||
- method: POST
|
||||
path: /api/v1/auth/initialize
|
||||
- method: GET
|
||||
path: /api/v1/auth/oauth/{provider}
|
||||
- method: GET
|
||||
path: /api/v1/auth/callback/{provider}
|
||||
- method: GET
|
||||
path: /docs
|
||||
|
||||
auth:
|
||||
- method: POST
|
||||
path: /api/v1/auth/change-password
|
||||
- method: GET
|
||||
path: /api/v1/auth/me
|
||||
|
||||
threads:
|
||||
- method: POST
|
||||
path: /api/threads
|
||||
capability: threads:write
|
||||
- method: POST
|
||||
path: /api/threads/search
|
||||
capability: threads:read
|
||||
- method: DELETE
|
||||
path: /api/threads/{thread_id}
|
||||
capability: threads:delete
|
||||
policies: [owner:thread]
|
||||
require_existing: false
|
||||
- method: GET
|
||||
path: /api/threads/{thread_id}/state
|
||||
capability: threads:read
|
||||
policies: [owner:thread]
|
||||
- method: POST
|
||||
path: /api/threads/{thread_id}/state
|
||||
capability: threads:write
|
||||
policies: [owner:thread]
|
||||
- method: POST
|
||||
path: /api/threads/{thread_id}/history
|
||||
capability: threads:read
|
||||
policies: [owner:thread]
|
||||
|
||||
runs:
|
||||
- method: GET
|
||||
path: /api/threads/{thread_id}/runs
|
||||
capability: runs:read
|
||||
policies: [owner:thread]
|
||||
- method: GET
|
||||
path: /api/threads/{thread_id}/runs/{run_id}
|
||||
capability: runs:read
|
||||
policies: [owner:run]
|
||||
- method: GET
|
||||
path: /api/threads/{thread_id}/runs/{run_id}/messages
|
||||
capability: runs:read
|
||||
policies: [owner:run]
|
||||
- method: POST
|
||||
path: /api/threads/{thread_id}/runs
|
||||
capability: runs:create
|
||||
policies: [owner:thread]
|
||||
- method: POST
|
||||
path: /api/threads/{thread_id}/runs/stream
|
||||
capability: runs:create
|
||||
policies: [owner:thread]
|
||||
- method: POST
|
||||
path: /api/threads/{thread_id}/runs/wait
|
||||
capability: runs:create
|
||||
policies: [owner:thread]
|
||||
- method: POST
|
||||
path: /api/threads/runs
|
||||
capability: runs:create
|
||||
- method: POST
|
||||
path: /api/threads/runs/stream
|
||||
capability: runs:create
|
||||
- method: POST
|
||||
path: /api/threads/runs/wait
|
||||
capability: runs:create
|
||||
- methods: [GET, POST]
|
||||
path: /api/threads/{thread_id}/runs/{run_id}/stream
|
||||
capability: runs:read
|
||||
policies: [owner:run]
|
||||
- method: GET
|
||||
path: /api/threads/{thread_id}/runs/{run_id}/join
|
||||
capability: runs:read
|
||||
policies: [owner:run]
|
||||
- method: POST
|
||||
path: /api/threads/{thread_id}/runs/{run_id}/cancel
|
||||
capability: runs:cancel
|
||||
policies: [owner:run]
|
||||
- method: DELETE
|
||||
path: /api/threads/{thread_id}/runs/{run_id}
|
||||
capability: runs:cancel
|
||||
policies: [owner:run]
|
||||
|
||||
feedback:
|
||||
- method: PUT
|
||||
path: /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||
policies: [owner:run]
|
||||
- method: POST
|
||||
path: /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||
policies: [owner:run]
|
||||
- method: GET
|
||||
path: /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||
policies: [owner:run]
|
||||
- method: GET
|
||||
path: /api/threads/{thread_id}/runs/{run_id}/feedback/stats
|
||||
policies: [owner:run]
|
||||
- method: DELETE
|
||||
path: /api/threads/{thread_id}/runs/{run_id}/feedback
|
||||
policies: [owner:run]
|
||||
- method: DELETE
|
||||
path: /api/threads/{thread_id}/runs/{run_id}/feedback/{feedback_id}
|
||||
policies: [owner:run]
|
||||
|
||||
suggestions:
|
||||
- method: POST
|
||||
path: /api/threads/{thread_id}/suggestions
|
||||
capability: threads:read
|
||||
policies: [owner:thread]
|
||||
|
||||
uploads:
|
||||
- method: POST
|
||||
path: /api/threads/{thread_id}/uploads
|
||||
capability: threads:write
|
||||
policies: [owner:thread]
|
||||
require_existing: false
|
||||
- method: GET
|
||||
path: /api/threads/{thread_id}/uploads/list
|
||||
capability: threads:read
|
||||
policies: [owner:thread]
|
||||
- method: DELETE
|
||||
path: /api/threads/{thread_id}/uploads/{filename}
|
||||
capability: threads:delete
|
||||
policies: [owner:thread]
|
||||
|
||||
artifacts:
|
||||
- method: GET
|
||||
path: /api/threads/{thread_id}/artifacts/{path:path}
|
||||
capability: threads:read
|
||||
policies: [owner:thread]
|
||||
|
||||
agents:
|
||||
- method: GET
|
||||
path: /api/agents
|
||||
- method: GET
|
||||
path: /api/agents/check
|
||||
- method: GET
|
||||
path: /api/agents/{name}
|
||||
- method: POST
|
||||
path: /api/agents
|
||||
- method: PUT
|
||||
path: /api/agents/{name}
|
||||
- method: GET
|
||||
path: /api/user-profile
|
||||
- method: PUT
|
||||
path: /api/user-profile
|
||||
- method: DELETE
|
||||
path: /api/agents/{name}
|
||||
|
||||
channels:
|
||||
- method: GET
|
||||
path: /api/channels/
|
||||
- method: POST
|
||||
path: /api/channels/{name}/restart
|
||||
|
||||
mcp:
|
||||
- method: GET
|
||||
path: /api/mcp/config
|
||||
- method: PUT
|
||||
path: /api/mcp/config
|
||||
|
||||
models:
|
||||
- method: GET
|
||||
path: /api/models
|
||||
- method: GET
|
||||
path: /api/models/{model_name}
|
||||
|
||||
skills:
|
||||
- method: GET
|
||||
path: /api/skills
|
||||
- method: POST
|
||||
path: /api/skills/install
|
||||
- method: GET
|
||||
path: /api/skills/custom
|
||||
- method: GET
|
||||
path: /api/skills/custom/{skill_name}
|
||||
- method: PUT
|
||||
path: /api/skills/custom/{skill_name}
|
||||
- method: DELETE
|
||||
path: /api/skills/custom/{skill_name}
|
||||
- method: GET
|
||||
path: /api/skills/custom/{skill_name}/history
|
||||
- method: POST
|
||||
path: /api/skills/custom/{skill_name}/rollback
|
||||
- method: GET
|
||||
path: /api/skills/{skill_name}
|
||||
- method: PUT
|
||||
path: /api/skills/{skill_name}
|
||||
@@ -0,0 +1,5 @@
|
||||
"""Runtime state utilities for the auth plugin."""
|
||||
|
||||
from app.plugins.auth.runtime.config_state import get_auth_config, reset_auth_config, set_auth_config
|
||||
|
||||
__all__ = ["get_auth_config", "reset_auth_config", "set_auth_config"]
|
||||
@@ -0,0 +1,27 @@
|
||||
"""Runtime state holder for auth configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from app.plugins.auth.domain.config import AuthConfig, load_auth_config_from_env
|
||||
|
||||
_auth_config: AuthConfig | None = None
|
||||
|
||||
|
||||
def get_auth_config() -> AuthConfig:
|
||||
global _auth_config
|
||||
if _auth_config is None:
|
||||
_auth_config = load_auth_config_from_env()
|
||||
return _auth_config
|
||||
|
||||
|
||||
def set_auth_config(config: AuthConfig) -> None:
|
||||
global _auth_config
|
||||
_auth_config = config
|
||||
|
||||
|
||||
def reset_auth_config() -> None:
|
||||
global _auth_config
|
||||
_auth_config = None
|
||||
|
||||
|
||||
__all__ = ["get_auth_config", "reset_auth_config", "set_auth_config"]
|
||||
@@ -0,0 +1,47 @@
|
||||
"""Security layer for the auth plugin."""
|
||||
|
||||
from app.plugins.auth.security.actor_context import (
|
||||
bind_request_actor_context,
|
||||
bind_user_actor_context,
|
||||
resolve_request_user_id,
|
||||
)
|
||||
from app.plugins.auth.security.csrf import (
|
||||
CSRF_COOKIE_NAME,
|
||||
CSRF_HEADER_NAME,
|
||||
CSRFMiddleware,
|
||||
get_csrf_token,
|
||||
is_secure_request,
|
||||
)
|
||||
from app.plugins.auth.security.dependencies import (
|
||||
CurrentAuthService,
|
||||
CurrentUserRepository,
|
||||
get_auth_service,
|
||||
get_current_user_from_request,
|
||||
get_current_user_id,
|
||||
get_optional_user_from_request,
|
||||
get_user_repository,
|
||||
)
|
||||
from app.plugins.auth.security.langgraph import add_owner_filter, auth, authenticate
|
||||
from app.plugins.auth.security.middleware import AuthMiddleware
|
||||
|
||||
__all__ = [
|
||||
"CSRF_COOKIE_NAME",
|
||||
"CSRF_HEADER_NAME",
|
||||
"CSRFMiddleware",
|
||||
"AuthMiddleware",
|
||||
"CurrentAuthService",
|
||||
"CurrentUserRepository",
|
||||
"add_owner_filter",
|
||||
"auth",
|
||||
"authenticate",
|
||||
"bind_request_actor_context",
|
||||
"bind_user_actor_context",
|
||||
"get_auth_service",
|
||||
"get_csrf_token",
|
||||
"get_current_user_from_request",
|
||||
"get_current_user_id",
|
||||
"get_optional_user_from_request",
|
||||
"get_user_repository",
|
||||
"is_secure_request",
|
||||
"resolve_request_user_id",
|
||||
]
|
||||
@@ -0,0 +1,43 @@
|
||||
"""Auth-plugin bridge from request user to runtime actor context."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
|
||||
from fastapi import Request
|
||||
|
||||
from deerflow.runtime.actor_context import ActorContext, bind_actor_context, reset_actor_context
|
||||
|
||||
|
||||
def resolve_request_user_id(request: Request) -> str | None:
|
||||
scope = getattr(request, "scope", None)
|
||||
user = scope.get("user") if isinstance(scope, dict) else None
|
||||
if user is None:
|
||||
state = getattr(request, "state", None)
|
||||
state_vars = vars(state) if state is not None and hasattr(state, "__dict__") else {}
|
||||
user = state_vars.get("user")
|
||||
user_id = getattr(user, "id", None)
|
||||
if user_id is None:
|
||||
return None
|
||||
return str(user_id)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def bind_request_actor_context(request: Request):
|
||||
token = bind_actor_context(ActorContext(user_id=resolve_request_user_id(request)))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
reset_actor_context(token)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def bind_user_actor_context(user_id: str | None):
|
||||
token = bind_actor_context(ActorContext(user_id=str(user_id) if user_id is not None else None))
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
reset_actor_context(token)
|
||||
|
||||
|
||||
__all__ = ["bind_request_actor_context", "bind_user_actor_context", "resolve_request_user_id"]
|
||||
@@ -0,0 +1,106 @@
|
||||
"""CSRF protection middleware and helpers for cookie-based auth flows."""
|
||||
|
||||
import secrets
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
CSRF_COOKIE_NAME = "csrf_token"
|
||||
CSRF_HEADER_NAME = "X-CSRF-Token"
|
||||
CSRF_TOKEN_LENGTH = 64 # bytes
|
||||
|
||||
|
||||
def is_secure_request(request: Request) -> bool:
|
||||
"""Detect whether the original client request was made over HTTPS."""
|
||||
return request.headers.get("x-forwarded-proto", request.url.scheme) == "https"
|
||||
|
||||
|
||||
def generate_csrf_token() -> str:
|
||||
"""Generate a secure random CSRF token."""
|
||||
return secrets.token_urlsafe(CSRF_TOKEN_LENGTH)
|
||||
|
||||
|
||||
def should_check_csrf(request: Request) -> bool:
|
||||
"""Determine if a request needs CSRF validation."""
|
||||
if request.method not in ("POST", "PUT", "DELETE", "PATCH"):
|
||||
return False
|
||||
|
||||
path = request.url.path.rstrip("/")
|
||||
if path == "/api/v1/auth/me":
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
_AUTH_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/initialize",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def is_auth_endpoint(request: Request) -> bool:
|
||||
"""Check if the request is to an auth endpoint."""
|
||||
return request.url.path.rstrip("/") in _AUTH_EXEMPT_PATHS
|
||||
|
||||
|
||||
class CSRFMiddleware(BaseHTTPMiddleware):
|
||||
"""Implement CSRF protection using the double-submit cookie pattern."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
_is_auth = is_auth_endpoint(request)
|
||||
|
||||
if should_check_csrf(request) and not _is_auth:
|
||||
cookie_token = request.cookies.get(CSRF_COOKIE_NAME)
|
||||
header_token = request.headers.get(CSRF_HEADER_NAME)
|
||||
|
||||
if not cookie_token or not header_token:
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "CSRF token missing. Include X-CSRF-Token header."},
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={"detail": "CSRF token mismatch."},
|
||||
)
|
||||
|
||||
response = await call_next(request)
|
||||
|
||||
if _is_auth and request.method == "POST":
|
||||
csrf_token = generate_csrf_token()
|
||||
response.set_cookie(
|
||||
key=CSRF_COOKIE_NAME,
|
||||
value=csrf_token,
|
||||
httponly=False,
|
||||
secure=is_secure_request(request),
|
||||
samesite="strict",
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
def get_csrf_token(request: Request) -> str | None:
|
||||
"""Get the CSRF token from the current request's cookies."""
|
||||
return request.cookies.get(CSRF_COOKIE_NAME)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"CSRF_COOKIE_NAME",
|
||||
"CSRF_HEADER_NAME",
|
||||
"CSRFMiddleware",
|
||||
"generate_csrf_token",
|
||||
"get_csrf_token",
|
||||
"is_auth_endpoint",
|
||||
"is_secure_request",
|
||||
"should_check_csrf",
|
||||
]
|
||||
@@ -0,0 +1,119 @@
|
||||
"""Security dependency helpers for the auth plugin."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.plugins.auth.domain.errors import (
|
||||
AuthErrorCode,
|
||||
AuthErrorResponse,
|
||||
TokenError,
|
||||
token_error_to_code,
|
||||
)
|
||||
from app.plugins.auth.domain.jwt import decode_token
|
||||
from app.plugins.auth.domain.service import AuthService
|
||||
from app.plugins.auth.storage import DbUserRepository, UserRepositoryProtocol
|
||||
|
||||
|
||||
def _get_session_factory(request: Request) -> async_sessionmaker[AsyncSession] | None:
|
||||
persistence = getattr(request.app.state, "persistence", None)
|
||||
if persistence is None:
|
||||
return None
|
||||
return getattr(persistence, "session_factory", None)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _auth_session(request: Request) -> AsyncIterator[AsyncSession]:
|
||||
injected = getattr(request.state, "_auth_session", None)
|
||||
if injected is not None:
|
||||
yield injected
|
||||
return
|
||||
|
||||
session_factory = _get_session_factory(request)
|
||||
if session_factory is None:
|
||||
raise HTTPException(status_code=503, detail="Auth session not available")
|
||||
|
||||
async with session_factory() as session:
|
||||
yield session
|
||||
|
||||
|
||||
async def get_user_repository(request: Request) -> UserRepositoryProtocol:
|
||||
async with _auth_session(request) as session:
|
||||
return DbUserRepository(session)
|
||||
|
||||
|
||||
def get_auth_service(request: Request) -> AuthService:
|
||||
session_factory = _get_session_factory(request)
|
||||
if session_factory is None:
|
||||
raise HTTPException(status_code=503, detail="Auth session factory not available")
|
||||
return AuthService(session_factory)
|
||||
|
||||
|
||||
async def get_current_user_from_request(request: Request):
|
||||
access_token = request.cookies.get("access_token")
|
||||
if not access_token:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(),
|
||||
)
|
||||
|
||||
payload = decode_token(access_token)
|
||||
if isinstance(payload, TokenError):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(
|
||||
code=token_error_to_code(payload),
|
||||
message=f"Token error: {payload.value}",
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
async with _auth_session(request) as session:
|
||||
user_repo = DbUserRepository(session)
|
||||
user = await user_repo.get_user_by_id(payload.sub)
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(),
|
||||
)
|
||||
|
||||
if user.token_version != payload.ver:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=AuthErrorResponse(
|
||||
code=AuthErrorCode.TOKEN_INVALID,
|
||||
message="Token revoked (password changed)",
|
||||
).model_dump(),
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
|
||||
async def get_optional_user_from_request(request: Request):
|
||||
try:
|
||||
return await get_current_user_from_request(request)
|
||||
except HTTPException:
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user_id(request: Request) -> str | None:
|
||||
user = await get_optional_user_from_request(request)
|
||||
return user.id if user else None
|
||||
|
||||
|
||||
CurrentUserRepository = Annotated[UserRepositoryProtocol, Depends(get_user_repository)]
|
||||
CurrentAuthService = Annotated[AuthService, Depends(get_auth_service)]
|
||||
|
||||
__all__ = [
|
||||
"CurrentAuthService",
|
||||
"CurrentUserRepository",
|
||||
"get_auth_service",
|
||||
"get_current_user_from_request",
|
||||
"get_current_user_id",
|
||||
"get_optional_user_from_request",
|
||||
"get_user_repository",
|
||||
]
|
||||
@@ -0,0 +1,64 @@
|
||||
"""LangGraph auth adapter for the auth plugin."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import secrets
|
||||
from types import SimpleNamespace
|
||||
|
||||
from langgraph_sdk import Auth
|
||||
|
||||
from app.plugins.auth.security.dependencies import get_current_user_from_request
|
||||
|
||||
auth = Auth()
|
||||
|
||||
_CSRF_METHODS = frozenset({"POST", "PUT", "DELETE", "PATCH"})
|
||||
|
||||
|
||||
def _check_csrf(request) -> None:
|
||||
method = getattr(request, "method", "") or ""
|
||||
if method.upper() not in _CSRF_METHODS:
|
||||
return
|
||||
|
||||
cookie_token = request.cookies.get("csrf_token")
|
||||
header_token = request.headers.get("x-csrf-token")
|
||||
|
||||
if not cookie_token or not header_token:
|
||||
raise Auth.exceptions.HTTPException(
|
||||
status_code=403,
|
||||
detail="CSRF token missing. Include X-CSRF-Token header.",
|
||||
)
|
||||
|
||||
if not secrets.compare_digest(cookie_token, header_token):
|
||||
raise Auth.exceptions.HTTPException(status_code=403, detail="CSRF token mismatch.")
|
||||
|
||||
|
||||
@auth.authenticate
|
||||
async def authenticate(request):
|
||||
_check_csrf(request)
|
||||
resolver_request = SimpleNamespace(
|
||||
cookies=getattr(request, "cookies", {}),
|
||||
state=SimpleNamespace(_auth_session=getattr(request, "_auth_session", None)),
|
||||
app=SimpleNamespace(state=SimpleNamespace(persistence=getattr(request, "_persistence", None))),
|
||||
)
|
||||
|
||||
try:
|
||||
user = await get_current_user_from_request(resolver_request)
|
||||
except Exception as exc:
|
||||
status_code = getattr(exc, "status_code", None)
|
||||
if status_code is None:
|
||||
raise
|
||||
detail = getattr(exc, "detail", "Not authenticated")
|
||||
message = detail.get("message") if isinstance(detail, dict) else str(detail)
|
||||
raise Auth.exceptions.HTTPException(status_code=status_code, detail=message) from exc
|
||||
|
||||
return user.id
|
||||
|
||||
|
||||
@auth.on
|
||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||
metadata = value.setdefault("metadata", {})
|
||||
metadata["user_id"] = ctx.user.identity
|
||||
return {"user_id": ctx.user.identity}
|
||||
|
||||
|
||||
__all__ = ["add_owner_filter", "auth", "authenticate"]
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Global authentication middleware for the auth plugin."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable
|
||||
|
||||
from fastapi import HTTPException, Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.plugins.auth.authorization import _ALL_PERMISSIONS, AuthContext
|
||||
from app.plugins.auth.domain.errors import AuthErrorCode, AuthErrorResponse
|
||||
from app.plugins.auth.injection.registry_loader import RoutePolicyRegistry
|
||||
from app.plugins.auth.security.dependencies import get_current_user_from_request
|
||||
from deerflow.runtime.actor_context import ActorContext, bind_actor_context, reset_actor_context
|
||||
|
||||
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = ("/health", "/docs", "/redoc", "/openapi.json")
|
||||
|
||||
_PUBLIC_EXACT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/login/local",
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/logout",
|
||||
"/api/v1/auth/setup-status",
|
||||
"/api/v1/auth/initialize",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _is_public(path: str) -> bool:
|
||||
stripped = path.rstrip("/")
|
||||
if stripped in _PUBLIC_EXACT_PATHS:
|
||||
return True
|
||||
return any(path.startswith(prefix) for prefix in _PUBLIC_PATH_PREFIXES)
|
||||
|
||||
|
||||
class AuthMiddleware(BaseHTTPMiddleware):
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
registry = getattr(request.app.state, "auth_route_policy_registry", None)
|
||||
is_public = False
|
||||
if isinstance(registry, RoutePolicyRegistry):
|
||||
is_public = registry.is_public_request(request.method, request.url.path)
|
||||
if is_public or _is_public(request.url.path):
|
||||
return await call_next(request)
|
||||
|
||||
if not request.cookies.get("access_token"):
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"detail": AuthErrorResponse(
|
||||
code=AuthErrorCode.NOT_AUTHENTICATED,
|
||||
message="Authentication required",
|
||||
).model_dump()
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
user = await get_current_user_from_request(request)
|
||||
except HTTPException as exc:
|
||||
return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
|
||||
|
||||
auth_context = AuthContext(user=user, permissions=_ALL_PERMISSIONS)
|
||||
request.scope["user"] = user
|
||||
request.scope["auth"] = auth_context
|
||||
request.state.user = user
|
||||
request.state.auth = auth_context
|
||||
token = bind_actor_context(ActorContext(user_id=str(user.id)))
|
||||
try:
|
||||
return await call_next(request)
|
||||
finally:
|
||||
reset_actor_context(token)
|
||||
|
||||
|
||||
__all__ = ["AuthMiddleware", "_is_public"]
|
||||
@@ -0,0 +1,17 @@
|
||||
"""Auth plugin storage package.
|
||||
|
||||
This package owns auth-specific ORM models and repositories while
|
||||
continuing to use the application's shared persistence infrastructure.
|
||||
"""
|
||||
|
||||
from app.plugins.auth.storage.contracts import User, UserCreate, UserRepositoryProtocol
|
||||
from app.plugins.auth.storage.models import User as UserModel
|
||||
from app.plugins.auth.storage.repositories import DbUserRepository
|
||||
|
||||
__all__ = [
|
||||
"DbUserRepository",
|
||||
"User",
|
||||
"UserCreate",
|
||||
"UserModel",
|
||||
"UserRepositoryProtocol",
|
||||
]
|
||||
@@ -0,0 +1,55 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Protocol
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
def _new_user_id() -> str:
|
||||
return str(uuid4())
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: str = Field(default_factory=_new_user_id)
|
||||
email: str
|
||||
password_hash: str | None = None
|
||||
system_role: str = "user"
|
||||
oauth_provider: str | None = None
|
||||
oauth_id: str | None = None
|
||||
needs_setup: bool = False
|
||||
token_version: int = 0
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
id: str
|
||||
email: str
|
||||
password_hash: str | None
|
||||
system_role: str
|
||||
oauth_provider: str | None
|
||||
oauth_id: str | None
|
||||
needs_setup: bool
|
||||
token_version: int
|
||||
created_time: datetime
|
||||
updated_time: datetime | None
|
||||
|
||||
|
||||
class UserRepositoryProtocol(Protocol):
|
||||
async def create_user(self, data: UserCreate) -> User: ...
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None: ...
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None: ...
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: ...
|
||||
|
||||
async def update_user(self, data: User) -> User: ...
|
||||
|
||||
async def count_users(self) -> int: ...
|
||||
|
||||
async def count_admin_users(self) -> int: ...
|
||||
@@ -0,0 +1,25 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import Boolean, Integer, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import Base
|
||||
|
||||
|
||||
class User(Base):
|
||||
"""Application user table."""
|
||||
|
||||
__tablename__ = "users"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("oauth_provider", "oauth_id", name="uq_users_oauth_identity"),
|
||||
{"comment": "Application user table."},
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String(64), primary_key=True, unique=True, index=True)
|
||||
email: Mapped[str] = mapped_column(String(255), unique=True, index=True)
|
||||
password_hash: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||
system_role: Mapped[str] = mapped_column(String(16), default="user", index=True)
|
||||
oauth_provider: Mapped[str | None] = mapped_column(String(64), default=None)
|
||||
oauth_id: Mapped[str | None] = mapped_column(String(255), default=None)
|
||||
needs_setup: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
token_version: Mapped[int] = mapped_column(Integer, default=0)
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.plugins.auth.storage.contracts import User, UserCreate, UserRepositoryProtocol
|
||||
from app.plugins.auth.storage.models import User as UserModel
|
||||
|
||||
|
||||
def _to_user(model: UserModel) -> User:
|
||||
return User(
|
||||
id=model.id,
|
||||
email=model.email,
|
||||
password_hash=model.password_hash,
|
||||
system_role=model.system_role,
|
||||
oauth_provider=model.oauth_provider,
|
||||
oauth_id=model.oauth_id,
|
||||
needs_setup=model.needs_setup,
|
||||
token_version=model.token_version,
|
||||
created_time=model.created_time,
|
||||
updated_time=model.updated_time,
|
||||
)
|
||||
|
||||
|
||||
class DbUserRepository(UserRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def create_user(self, data: UserCreate) -> User:
|
||||
model = UserModel(
|
||||
id=data.id,
|
||||
email=data.email,
|
||||
password_hash=data.password_hash,
|
||||
system_role=data.system_role,
|
||||
oauth_provider=data.oauth_provider,
|
||||
oauth_id=data.oauth_id,
|
||||
needs_setup=data.needs_setup,
|
||||
token_version=data.token_version,
|
||||
)
|
||||
self._session.add(model)
|
||||
try:
|
||||
await self._session.flush()
|
||||
except IntegrityError as exc:
|
||||
await self._session.rollback()
|
||||
raise ValueError("User already exists") from exc
|
||||
await self._session.refresh(model)
|
||||
return _to_user(model)
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
model = await self._session.get(UserModel, user_id)
|
||||
return _to_user(model) if model else None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
result = await self._session.execute(select(UserModel).where(UserModel.email == email))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_user(model) if model else None
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
result = await self._session.execute(
|
||||
select(UserModel).where(
|
||||
UserModel.oauth_provider == provider,
|
||||
UserModel.oauth_id == oauth_id,
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_user(model) if model else None
|
||||
|
||||
async def update_user(self, data: User) -> User:
|
||||
model = await self._session.get(UserModel, data.id)
|
||||
if model is None:
|
||||
raise LookupError(f"User {data.id} not found")
|
||||
model.email = data.email
|
||||
model.password_hash = data.password_hash
|
||||
model.system_role = data.system_role
|
||||
model.oauth_provider = data.oauth_provider
|
||||
model.oauth_id = data.oauth_id
|
||||
model.needs_setup = data.needs_setup
|
||||
model.token_version = data.token_version
|
||||
try:
|
||||
await self._session.flush()
|
||||
except IntegrityError as exc:
|
||||
await self._session.rollback()
|
||||
raise ValueError("User already exists") from exc
|
||||
await self._session.refresh(model)
|
||||
return _to_user(model)
|
||||
|
||||
async def count_users(self) -> int:
|
||||
return await self._session.scalar(select(func.count()).select_from(UserModel)) or 0
|
||||
|
||||
async def count_admin_users(self) -> int:
|
||||
return (
|
||||
await self._session.scalar(
|
||||
select(func.count()).select_from(UserModel).where(UserModel.system_role == "admin")
|
||||
)
|
||||
or 0
|
||||
)
|
||||
Reference in New Issue
Block a user