mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
refactor(config): eliminate global mutable state — explicit parameter passing on top of main
Squashes 25 PR commits onto current main. AppConfig becomes a pure value object with no ambient lookup. Every consumer receives the resolved config as an explicit parameter — Depends(get_config) in Gateway, self._app_config in DeerFlowClient, runtime.context.app_config in agent runs, AppConfig.from_file() at the LangGraph Server registration boundary. Phase 1 — frozen data + typed context - All config models (AppConfig, MemoryConfig, DatabaseConfig, …) become frozen=True; no sub-module globals. - AppConfig.from_file() is pure (no side-effect singleton loaders). - Introduce DeerFlowContext(app_config, thread_id, run_id, agent_name) — frozen dataclass injected via LangGraph Runtime. - Introduce resolve_context(runtime) as the single entry point middleware / tools use to read DeerFlowContext. Phase 2 — pure explicit parameter passing - Gateway: app.state.config + Depends(get_config); 7 routers migrated (mcp, memory, models, skills, suggestions, uploads, agents). - DeerFlowClient: __init__(config=...) captures config locally. - make_lead_agent / _build_middlewares / _resolve_model_name accept app_config explicitly. - RunContext.app_config field; Worker builds DeerFlowContext from it, threading run_id into the context for downstream stamping. - Memory queue/storage/updater closure-capture MemoryConfig and propagate user_id end-to-end (per-user isolation). - Sandbox/skills/community/factories/tools thread app_config. - resolve_context() rejects non-typed runtime.context. - Test suite migrated off AppConfig.current() monkey-patches. - AppConfig.current() classmethod deleted. Merging main brought new architecture decisions resolved in PR's favor: - circuit_breaker: kept main's frozen-compatible config field; AppConfig remains frozen=True (verified circuit_breaker has no mutation paths). - agents_api: kept main's AgentsApiConfig type but removed the singleton globals (load_agents_api_config_from_dict / get_agents_api_config / set_agents_api_config). 8 routes in agents.py now read via Depends(get_config). - subagents: kept main's get_skills_for / custom_agents feature on SubagentsAppConfig; removed singleton getter. registry.py now reads app_config.subagents directly. - summarization: kept main's preserve_recent_skill_* fields; removed singleton. - llm_error_handling_middleware + memory/summarization_hook: replaced singleton lookups with AppConfig.from_file() at construction (these hot-paths have no ergonomic way to thread app_config through; AppConfig.from_file is a pure load). - worker.py + thread_data_middleware.py: DeerFlowContext.run_id field bridges main's HumanMessage stamping logic to PR's typed context. Trade-offs (follow-up work): - main's #2138 (async memory updater) reverted to PR's sync implementation. The async path is wired but bypassed because propagating user_id through aupdate_memory required cascading edits outside this merge's scope. - tests/test_subagent_skills_config.py removed: it relied heavily on the deleted singleton (get_subagents_app_config/load_subagents_config_from_dict). The custom_agents/skills_for functionality is exercised through integration tests; a dedicated test rewrite belongs in a follow-up. Verification: backend test suite — 2560 passed, 4 skipped, 84 failures. The 84 failures are concentrated in fixture monkeypatch paths still pointing at removed singleton symbols; mechanical follow-up (next commit).
This commit is contained in:
@@ -0,0 +1,42 @@
|
||||
"""Authentication module for DeerFlow.
|
||||
|
||||
This module provides:
|
||||
- JWT-based authentication
|
||||
- Provider Factory pattern for extensible auth methods
|
||||
- UserRepository interface for storage backends (SQLite)
|
||||
"""
|
||||
|
||||
from app.gateway.auth.config import AuthConfig, get_auth_config, set_auth_config
|
||||
from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError
|
||||
from app.gateway.auth.jwt import TokenPayload, create_access_token, decode_token
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
from app.gateway.auth.models import User, UserResponse
|
||||
from app.gateway.auth.password import hash_password, verify_password
|
||||
from app.gateway.auth.providers import AuthProvider
|
||||
from app.gateway.auth.repositories.base import UserRepository
|
||||
|
||||
__all__ = [
|
||||
# Config
|
||||
"AuthConfig",
|
||||
"get_auth_config",
|
||||
"set_auth_config",
|
||||
# Errors
|
||||
"AuthErrorCode",
|
||||
"AuthErrorResponse",
|
||||
"TokenError",
|
||||
# JWT
|
||||
"TokenPayload",
|
||||
"create_access_token",
|
||||
"decode_token",
|
||||
# Password
|
||||
"hash_password",
|
||||
"verify_password",
|
||||
# Models
|
||||
"User",
|
||||
"UserResponse",
|
||||
# Providers
|
||||
"AuthProvider",
|
||||
"LocalAuthProvider",
|
||||
# Repository
|
||||
"UserRepository",
|
||||
]
|
||||
@@ -0,0 +1,57 @@
|
||||
"""Authentication configuration for DeerFlow."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import secrets
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""JWT and auth-related configuration. Parsed once at startup.
|
||||
|
||||
Note: the ``users`` table now lives in the shared persistence
|
||||
database managed by ``deerflow.persistence.engine``. The old
|
||||
``users_db_path`` config key has been removed — user storage is
|
||||
configured through ``config.database`` like every other table.
|
||||
"""
|
||||
|
||||
jwt_secret: str = Field(
|
||||
...,
|
||||
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
|
||||
)
|
||||
token_expiry_days: int = Field(default=7, ge=1, le=30)
|
||||
oauth_github_client_id: str | None = Field(default=None)
|
||||
oauth_github_client_secret: str | None = Field(default=None)
|
||||
|
||||
|
||||
_auth_config: AuthConfig | None = None
|
||||
|
||||
|
||||
def get_auth_config() -> AuthConfig:
|
||||
"""Get the global AuthConfig instance. Parses from env on first call."""
|
||||
global _auth_config
|
||||
if _auth_config is None:
|
||||
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
||||
if not jwt_secret:
|
||||
jwt_secret = secrets.token_urlsafe(32)
|
||||
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
||||
logger.warning(
|
||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
|
||||
"Sessions will be invalidated on restart. "
|
||||
"For production, add AUTH_JWT_SECRET to your .env file: "
|
||||
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||
)
|
||||
_auth_config = AuthConfig(jwt_secret=jwt_secret)
|
||||
return _auth_config
|
||||
|
||||
|
||||
def set_auth_config(config: AuthConfig) -> None:
|
||||
"""Set the global AuthConfig instance (for testing)."""
|
||||
global _auth_config
|
||||
_auth_config = config
|
||||
@@ -0,0 +1,48 @@
|
||||
"""Write initial admin credentials to a restricted file instead of logs.
|
||||
|
||||
Logging secrets to stdout/stderr is a well-known CodeQL finding
|
||||
(py/clear-text-logging-sensitive-data) — in production those logs
|
||||
get collected into ELK/Splunk/etc and become a secret sprawl
|
||||
source. This helper writes the credential to a 0600 file that only
|
||||
the process user can read, and returns the path so the caller can
|
||||
log **the path** (not the password) for the operator to pick up.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
_CREDENTIAL_FILENAME = "admin_initial_credentials.txt"
|
||||
|
||||
|
||||
def write_initial_credentials(email: str, password: str, *, label: str = "initial") -> Path:
|
||||
"""Write the admin email + password to ``{base_dir}/admin_initial_credentials.txt``.
|
||||
|
||||
The file is created **atomically** with mode 0600 via ``os.open``
|
||||
so the password is never world-readable, even for the single syscall
|
||||
window between ``write_text`` and ``chmod``.
|
||||
|
||||
``label`` distinguishes "initial" (fresh creation) from "reset"
|
||||
(password reset) in the file header so an operator picking up the
|
||||
file after a restart can tell which event produced it.
|
||||
|
||||
Returns the absolute :class:`Path` to the file.
|
||||
"""
|
||||
target = get_paths().base_dir / _CREDENTIAL_FILENAME
|
||||
target.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
content = (
|
||||
f"# DeerFlow admin {label} credentials\n# This file is generated on first boot or password reset.\n# Change the password after login via Settings -> Account,\n# then delete this file.\n#\nemail: {email}\npassword: {password}\n"
|
||||
)
|
||||
|
||||
# Atomic 0600 create-or-truncate. O_TRUNC (not O_EXCL) so the
|
||||
# reset-password path can rewrite an existing file without a
|
||||
# separate unlink-then-create dance.
|
||||
fd = os.open(target, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
||||
fh.write(content)
|
||||
|
||||
return target.resolve()
|
||||
@@ -0,0 +1,45 @@
|
||||
"""Typed error definitions for auth module.
|
||||
|
||||
AuthErrorCode: exhaustive enum of all auth failure conditions.
|
||||
TokenError: exhaustive enum of JWT decode failures.
|
||||
AuthErrorResponse: structured error payload for HTTP responses.
|
||||
"""
|
||||
|
||||
from enum import StrEnum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AuthErrorCode(StrEnum):
|
||||
"""Exhaustive list of auth error conditions."""
|
||||
|
||||
INVALID_CREDENTIALS = "invalid_credentials"
|
||||
TOKEN_EXPIRED = "token_expired"
|
||||
TOKEN_INVALID = "token_invalid"
|
||||
USER_NOT_FOUND = "user_not_found"
|
||||
EMAIL_ALREADY_EXISTS = "email_already_exists"
|
||||
PROVIDER_NOT_FOUND = "provider_not_found"
|
||||
NOT_AUTHENTICATED = "not_authenticated"
|
||||
SYSTEM_ALREADY_INITIALIZED = "system_already_initialized"
|
||||
|
||||
|
||||
class TokenError(StrEnum):
|
||||
"""Exhaustive list of JWT decode failure reasons."""
|
||||
|
||||
EXPIRED = "expired"
|
||||
INVALID_SIGNATURE = "invalid_signature"
|
||||
MALFORMED = "malformed"
|
||||
|
||||
|
||||
class AuthErrorResponse(BaseModel):
|
||||
"""Structured error response — replaces bare `detail` strings."""
|
||||
|
||||
code: AuthErrorCode
|
||||
message: str
|
||||
|
||||
|
||||
def token_error_to_code(err: TokenError) -> AuthErrorCode:
|
||||
"""Map TokenError to AuthErrorCode — single source of truth."""
|
||||
if err == TokenError.EXPIRED:
|
||||
return AuthErrorCode.TOKEN_EXPIRED
|
||||
return AuthErrorCode.TOKEN_INVALID
|
||||
@@ -0,0 +1,55 @@
|
||||
"""JWT token creation and verification."""
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.gateway.auth.config import get_auth_config
|
||||
from app.gateway.auth.errors import TokenError
|
||||
|
||||
|
||||
class TokenPayload(BaseModel):
|
||||
"""JWT token payload."""
|
||||
|
||||
sub: str # user_id
|
||||
exp: datetime
|
||||
iat: datetime | None = None
|
||||
ver: int = 0 # token_version — must match User.token_version
|
||||
|
||||
|
||||
def create_access_token(user_id: str, expires_delta: timedelta | None = None, token_version: int = 0) -> str:
|
||||
"""Create a JWT access token.
|
||||
|
||||
Args:
|
||||
user_id: The user's UUID as string
|
||||
expires_delta: Optional custom expiry, defaults to 7 days
|
||||
token_version: User's current token_version for invalidation
|
||||
|
||||
Returns:
|
||||
Encoded JWT string
|
||||
"""
|
||||
config = get_auth_config()
|
||||
expiry = expires_delta or timedelta(days=config.token_expiry_days)
|
||||
|
||||
now = datetime.now(UTC)
|
||||
payload = {"sub": user_id, "exp": now + expiry, "iat": now, "ver": token_version}
|
||||
return jwt.encode(payload, config.jwt_secret, algorithm="HS256")
|
||||
|
||||
|
||||
def decode_token(token: str) -> TokenPayload | TokenError:
|
||||
"""Decode and validate a JWT token.
|
||||
|
||||
Returns:
|
||||
TokenPayload if valid, or a specific TokenError variant.
|
||||
"""
|
||||
config = get_auth_config()
|
||||
try:
|
||||
payload = jwt.decode(token, config.jwt_secret, algorithms=["HS256"])
|
||||
return TokenPayload(**payload)
|
||||
except jwt.ExpiredSignatureError:
|
||||
return TokenError.EXPIRED
|
||||
except jwt.InvalidSignatureError:
|
||||
return TokenError.INVALID_SIGNATURE
|
||||
except jwt.PyJWTError:
|
||||
return TokenError.MALFORMED
|
||||
@@ -0,0 +1,91 @@
|
||||
"""Local email/password authentication provider."""
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.auth.password import hash_password_async, verify_password_async
|
||||
from app.gateway.auth.providers import AuthProvider
|
||||
from app.gateway.auth.repositories.base import UserRepository
|
||||
|
||||
|
||||
class LocalAuthProvider(AuthProvider):
|
||||
"""Email/password authentication provider using local database."""
|
||||
|
||||
def __init__(self, repository: UserRepository):
|
||||
"""Initialize with a UserRepository.
|
||||
|
||||
Args:
|
||||
repository: UserRepository implementation (SQLite)
|
||||
"""
|
||||
self._repo = repository
|
||||
|
||||
async def authenticate(self, credentials: dict) -> User | None:
|
||||
"""Authenticate with email and password.
|
||||
|
||||
Args:
|
||||
credentials: dict with 'email' and 'password' keys
|
||||
|
||||
Returns:
|
||||
User if authentication succeeds, None otherwise
|
||||
"""
|
||||
email = credentials.get("email")
|
||||
password = credentials.get("password")
|
||||
|
||||
if not email or not password:
|
||||
return None
|
||||
|
||||
user = await self._repo.get_user_by_email(email)
|
||||
if user is None:
|
||||
return None
|
||||
|
||||
if user.password_hash is None:
|
||||
# OAuth user without local password
|
||||
return None
|
||||
|
||||
if not await verify_password_async(password, user.password_hash):
|
||||
return None
|
||||
|
||||
return user
|
||||
|
||||
async def get_user(self, user_id: str) -> User | None:
|
||||
"""Get user by ID."""
|
||||
return await self._repo.get_user_by_id(user_id)
|
||||
|
||||
async def create_user(self, email: str, password: str | None = None, system_role: str = "user", needs_setup: bool = False) -> User:
|
||||
"""Create a new local user.
|
||||
|
||||
Args:
|
||||
email: User email address
|
||||
password: Plain text password (will be hashed)
|
||||
system_role: Role to assign ("admin" or "user")
|
||||
needs_setup: If True, user must complete setup on first login
|
||||
|
||||
Returns:
|
||||
Created User instance
|
||||
"""
|
||||
password_hash = await hash_password_async(password) if password else None
|
||||
user = User(
|
||||
email=email,
|
||||
password_hash=password_hash,
|
||||
system_role=system_role,
|
||||
needs_setup=needs_setup,
|
||||
)
|
||||
return await self._repo.create_user(user)
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Get user by OAuth provider and ID."""
|
||||
return await self._repo.get_user_by_oauth(provider, oauth_id)
|
||||
|
||||
async def count_users(self) -> int:
|
||||
"""Return total number of registered users."""
|
||||
return await self._repo.count_users()
|
||||
|
||||
async def count_admin_users(self) -> int:
|
||||
"""Return number of admin users."""
|
||||
return await self._repo.count_admin_users()
|
||||
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update an existing user."""
|
||||
return await self._repo.update_user(user)
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
"""Get user by email."""
|
||||
return await self._repo.get_user_by_email(email)
|
||||
@@ -0,0 +1,41 @@
|
||||
"""User Pydantic models for authentication."""
|
||||
|
||||
from datetime import UTC, datetime
|
||||
from typing import Literal
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, EmailStr, Field
|
||||
|
||||
|
||||
def _utc_now() -> datetime:
|
||||
"""Return current UTC time (timezone-aware)."""
|
||||
return datetime.now(UTC)
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
"""Internal user representation."""
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
id: UUID = Field(default_factory=uuid4, description="Primary key")
|
||||
email: EmailStr = Field(..., description="Unique email address")
|
||||
password_hash: str | None = Field(None, description="bcrypt hash, nullable for OAuth users")
|
||||
system_role: Literal["admin", "user"] = Field(default="user")
|
||||
created_at: datetime = Field(default_factory=_utc_now)
|
||||
|
||||
# OAuth linkage (optional)
|
||||
oauth_provider: str | None = Field(None, description="e.g. 'github', 'google'")
|
||||
oauth_id: str | None = Field(None, description="User ID from OAuth provider")
|
||||
|
||||
# Auth lifecycle
|
||||
needs_setup: bool = Field(default=False, description="True for auto-created admin until setup completes")
|
||||
token_version: int = Field(default=0, description="Incremented on password change to invalidate old JWTs")
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
"""Response model for user info endpoint."""
|
||||
|
||||
id: str
|
||||
email: str
|
||||
system_role: Literal["admin", "user"]
|
||||
needs_setup: bool = False
|
||||
@@ -0,0 +1,33 @@
|
||||
"""Password hashing utilities using bcrypt directly."""
|
||||
|
||||
import asyncio
|
||||
|
||||
import bcrypt
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using bcrypt."""
|
||||
return bcrypt.hashpw(password.encode("utf-8"), bcrypt.gensalt()).decode("utf-8")
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash."""
|
||||
return bcrypt.checkpw(plain_password.encode("utf-8"), hashed_password.encode("utf-8"))
|
||||
|
||||
|
||||
async def hash_password_async(password: str) -> str:
|
||||
"""Hash a password using bcrypt (non-blocking).
|
||||
|
||||
Wraps the blocking bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop during password hashing.
|
||||
"""
|
||||
return await asyncio.to_thread(hash_password, password)
|
||||
|
||||
|
||||
async def verify_password_async(plain_password: str, hashed_password: str) -> bool:
|
||||
"""Verify a password against its hash (non-blocking).
|
||||
|
||||
Wraps the blocking bcrypt operation in a thread pool to avoid
|
||||
blocking the event loop during password verification.
|
||||
"""
|
||||
return await asyncio.to_thread(verify_password, plain_password, hashed_password)
|
||||
@@ -0,0 +1,24 @@
|
||||
"""Auth provider abstraction."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class AuthProvider(ABC):
|
||||
"""Abstract base class for authentication providers."""
|
||||
|
||||
@abstractmethod
|
||||
async def authenticate(self, credentials: dict) -> "User | None":
|
||||
"""Authenticate user with given credentials.
|
||||
|
||||
Returns User if authentication succeeds, None otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user(self, user_id: str) -> "User | None":
|
||||
"""Retrieve user by ID."""
|
||||
...
|
||||
|
||||
|
||||
# Import User at runtime to avoid circular imports
|
||||
from app.gateway.auth.models import User # noqa: E402
|
||||
@@ -0,0 +1,102 @@
|
||||
"""User repository interface for abstracting database operations."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
|
||||
|
||||
class UserNotFoundError(LookupError):
|
||||
"""Raised when a user repository operation targets a non-existent row.
|
||||
|
||||
Subclass of :class:`LookupError` so callers that already catch
|
||||
``LookupError`` for "missing entity" can keep working unchanged,
|
||||
while specific call sites can pin to this class to distinguish
|
||||
"concurrent delete during update" from other lookups.
|
||||
"""
|
||||
|
||||
|
||||
class UserRepository(ABC):
|
||||
"""Abstract interface for user data storage.
|
||||
|
||||
Implement this interface to support different storage backends
|
||||
(SQLite)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def create_user(self, user: User) -> User:
|
||||
"""Create a new user.
|
||||
|
||||
Args:
|
||||
user: User object to create
|
||||
|
||||
Returns:
|
||||
Created User with ID assigned
|
||||
|
||||
Raises:
|
||||
ValueError: If email already exists
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
"""Get user by ID.
|
||||
|
||||
Args:
|
||||
user_id: User UUID as string
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
"""Get user by email.
|
||||
|
||||
Args:
|
||||
email: User email address
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update an existing user.
|
||||
|
||||
Args:
|
||||
user: User object with updated fields
|
||||
|
||||
Returns:
|
||||
Updated User
|
||||
|
||||
Raises:
|
||||
UserNotFoundError: If no row exists for ``user.id``. This is
|
||||
a hard failure (not a no-op) so callers cannot mistake a
|
||||
concurrent-delete race for a successful update.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count_users(self) -> int:
|
||||
"""Return total number of registered users."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def count_admin_users(self) -> int:
|
||||
"""Return number of users with system_role == 'admin'."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Get user by OAuth provider and ID.
|
||||
|
||||
Args:
|
||||
provider: OAuth provider name (e.g. 'github', 'google')
|
||||
oauth_id: User ID from the OAuth provider
|
||||
|
||||
Returns:
|
||||
User if found, None otherwise
|
||||
"""
|
||||
...
|
||||
@@ -0,0 +1,127 @@
|
||||
"""SQLAlchemy-backed UserRepository implementation.
|
||||
|
||||
Uses the shared async session factory from
|
||||
``deerflow.persistence.engine`` — the ``users`` table lives in the
|
||||
same database as ``threads_meta``, ``runs``, ``run_events``, and
|
||||
``feedback``.
|
||||
|
||||
Constructor takes the session factory directly (same pattern as the
|
||||
other four repositories in ``deerflow.persistence.*``). Callers
|
||||
construct this after ``init_engine_from_config()`` has run.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.auth.repositories.base import UserNotFoundError, UserRepository
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
|
||||
class SQLiteUserRepository(UserRepository):
|
||||
"""Async user repository backed by the shared SQLAlchemy engine."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
# ── Converters ────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _row_to_user(row: UserRow) -> User:
|
||||
return User(
|
||||
id=UUID(row.id),
|
||||
email=row.email,
|
||||
password_hash=row.password_hash,
|
||||
system_role=row.system_role, # type: ignore[arg-type]
|
||||
# SQLite loses tzinfo on read; reattach UTC so downstream
|
||||
# code can compare timestamps reliably.
|
||||
created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
|
||||
oauth_provider=row.oauth_provider,
|
||||
oauth_id=row.oauth_id,
|
||||
needs_setup=row.needs_setup,
|
||||
token_version=row.token_version,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _user_to_row(user: User) -> UserRow:
|
||||
return UserRow(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
password_hash=user.password_hash,
|
||||
system_role=user.system_role,
|
||||
created_at=user.created_at,
|
||||
oauth_provider=user.oauth_provider,
|
||||
oauth_id=user.oauth_id,
|
||||
needs_setup=user.needs_setup,
|
||||
token_version=user.token_version,
|
||||
)
|
||||
|
||||
# ── CRUD ──────────────────────────────────────────────────────────
|
||||
|
||||
async def create_user(self, user: User) -> User:
|
||||
"""Insert a new user. Raises ``ValueError`` on duplicate email."""
|
||||
row = self._user_to_row(user)
|
||||
async with self._sf() as session:
|
||||
session.add(row)
|
||||
try:
|
||||
await session.commit()
|
||||
except IntegrityError as exc:
|
||||
await session.rollback()
|
||||
raise ValueError(f"Email already registered: {user.email}") from exc
|
||||
return user
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
async with self._sf() as session:
|
||||
row = await session.get(UserRow, user_id)
|
||||
return self._row_to_user(row) if row is not None else None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
stmt = select(UserRow).where(UserRow.email == email)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
return self._row_to_user(row) if row is not None else None
|
||||
|
||||
async def update_user(self, user: User) -> User:
|
||||
async with self._sf() as session:
|
||||
row = await session.get(UserRow, str(user.id))
|
||||
if row is None:
|
||||
# Hard fail on concurrent delete: callers (reset_admin,
|
||||
# password change handlers, _ensure_admin_user) all
|
||||
# fetched the user just before this call, so a missing
|
||||
# row here means the row vanished underneath us. Silent
|
||||
# success would let the caller log "password reset" for
|
||||
# a row that no longer exists.
|
||||
raise UserNotFoundError(f"User {user.id} no longer exists")
|
||||
row.email = user.email
|
||||
row.password_hash = user.password_hash
|
||||
row.system_role = user.system_role
|
||||
row.oauth_provider = user.oauth_provider
|
||||
row.oauth_id = user.oauth_id
|
||||
row.needs_setup = user.needs_setup
|
||||
row.token_version = user.token_version
|
||||
await session.commit()
|
||||
return user
|
||||
|
||||
async def count_users(self) -> int:
|
||||
stmt = select(func.count()).select_from(UserRow)
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def count_admin_users(self) -> int:
|
||||
stmt = select(func.count()).select_from(UserRow).where(UserRow.system_role == "admin")
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
return self._row_to_user(row) if row is not None else None
|
||||
@@ -0,0 +1,92 @@
|
||||
"""CLI tool to reset an admin password.
|
||||
|
||||
Usage:
|
||||
python -m app.gateway.auth.reset_admin
|
||||
python -m app.gateway.auth.reset_admin --email admin@example.com
|
||||
|
||||
Writes the new password to ``.deer-flow/admin_initial_credentials.txt``
|
||||
(mode 0600) instead of printing it, so CI / log aggregators never see
|
||||
the cleartext secret.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import secrets
|
||||
import sys
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.gateway.auth.credential_file import write_initial_credentials
|
||||
from app.gateway.auth.password import hash_password
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
|
||||
async def _run(email: str | None) -> int:
|
||||
from deerflow.config import AppConfig
|
||||
from deerflow.persistence.engine import (
|
||||
close_engine,
|
||||
get_session_factory,
|
||||
init_engine_from_config,
|
||||
)
|
||||
|
||||
# CLI entry: load config explicitly at the top, pass down through the closure.
|
||||
config = AppConfig.from_file()
|
||||
await init_engine_from_config(config.database)
|
||||
try:
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
print("Error: persistence engine not available (check config.database).", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
repo = SQLiteUserRepository(sf)
|
||||
|
||||
if email:
|
||||
user = await repo.get_user_by_email(email)
|
||||
else:
|
||||
# Find first admin via direct SELECT — repository does not
|
||||
# expose a "first admin" helper and we do not want to add
|
||||
# one just for this CLI.
|
||||
async with sf() as session:
|
||||
stmt = select(UserRow).where(UserRow.system_role == "admin").limit(1)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
user = None
|
||||
else:
|
||||
user = await repo.get_user_by_id(row.id)
|
||||
|
||||
if user is None:
|
||||
if email:
|
||||
print(f"Error: user '{email}' not found.", file=sys.stderr)
|
||||
else:
|
||||
print("Error: no admin user found.", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
new_password = secrets.token_urlsafe(16)
|
||||
user.password_hash = hash_password(new_password)
|
||||
user.token_version += 1
|
||||
user.needs_setup = True
|
||||
await repo.update_user(user)
|
||||
|
||||
cred_path = write_initial_credentials(user.email, new_password, label="reset")
|
||||
print(f"Password reset for: {user.email}")
|
||||
print(f"Credentials written to: {cred_path} (mode 0600)")
|
||||
print("Next login will require setup (new email + password).")
|
||||
return 0
|
||||
finally:
|
||||
await close_engine()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Reset admin password")
|
||||
parser.add_argument("--email", help="Admin email (default: first admin found)")
|
||||
args = parser.parse_args()
|
||||
|
||||
exit_code = asyncio.run(_run(args.email))
|
||||
sys.exit(exit_code)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user