mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 00:45:57 +00:00
feat(auth): introduce backend auth module
Port RFC-001 authentication core from PR #1728: - JWT token handling (create_access_token, decode_token, TokenPayload) - Password hashing (bcrypt) with verify_password - SQLite UserRepository with base interface - Provider Factory pattern (LocalAuthProvider) - CLI reset_admin tool - Auth-specific errors (AuthErrorCode, TokenError, AuthErrorResponse) Deps: - bcrypt>=4.0.0 - pyjwt>=2.9.0 - email-validator>=2.0.0 - backend/uv.toml pins public PyPI index Tests: 12 pure unit tests (test_auth_config.py, test_auth_errors.py). Scope note: authz.py, test_auth.py, and test_auth_type_system.py are deferred to commit 2 because they depend on middleware and deps wiring that is not yet in place. Commit 1 stays "pure new files only" as the spec mandates.
This commit is contained in:
@@ -0,0 +1,196 @@
|
||||
"""SQLite implementation of UserRepository."""
|
||||
|
||||
import asyncio
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from app.gateway.auth.config import get_auth_config
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.auth.repositories.base import UserRepository
|
||||
|
||||
_resolved_db_path: Path | None = None
|
||||
_table_initialized: bool = False
|
||||
|
||||
|
||||
def _get_users_db_path() -> Path:
|
||||
"""Get the users database path (resolved and cached once)."""
|
||||
global _resolved_db_path
|
||||
if _resolved_db_path is not None:
|
||||
return _resolved_db_path
|
||||
config = get_auth_config()
|
||||
if config.users_db_path:
|
||||
_resolved_db_path = Path(config.users_db_path)
|
||||
else:
|
||||
_resolved_db_path = Path(".deer-flow/users.db")
|
||||
_resolved_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return _resolved_db_path
|
||||
|
||||
|
||||
def _get_connection() -> sqlite3.Connection:
|
||||
"""Get a SQLite connection for the users database."""
|
||||
db_path = _get_users_db_path()
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
def _init_users_table(conn: sqlite3.Connection) -> None:
|
||||
"""Initialize the users table if it doesn't exist."""
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT,
|
||||
system_role TEXT NOT NULL DEFAULT 'user',
|
||||
created_at REAL NOT NULL,
|
||||
oauth_provider TEXT,
|
||||
oauth_id TEXT,
|
||||
needs_setup INTEGER NOT NULL DEFAULT 0,
|
||||
token_version INTEGER NOT NULL DEFAULT 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
# Add unique constraint for OAuth identity to prevent duplicate social logins
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_oauth_identity
|
||||
ON users(oauth_provider, oauth_id)
|
||||
WHERE oauth_provider IS NOT NULL AND oauth_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _get_users_conn():
|
||||
"""Context manager for users database connection."""
|
||||
global _table_initialized
|
||||
conn = _get_connection()
|
||||
try:
|
||||
if not _table_initialized:
|
||||
_init_users_table(conn)
|
||||
_table_initialized = True
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
class SQLiteUserRepository(UserRepository):
|
||||
"""SQLite implementation of UserRepository."""
|
||||
|
||||
async def create_user(self, user: User) -> User:
|
||||
"""Create a new user in SQLite."""
|
||||
return await asyncio.to_thread(self._create_user_sync, user)
|
||||
|
||||
def _create_user_sync(self, user: User) -> User:
|
||||
"""Synchronous user creation (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
try:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO users (id, email, password_hash, system_role, created_at, oauth_provider, oauth_id, needs_setup, token_version)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(user.id),
|
||||
user.email,
|
||||
user.password_hash,
|
||||
user.system_role,
|
||||
datetime.now(UTC).timestamp(),
|
||||
user.oauth_provider,
|
||||
user.oauth_id,
|
||||
int(user.needs_setup),
|
||||
user.token_version,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.IntegrityError as e:
|
||||
if "UNIQUE constraint failed: users.email" in str(e):
|
||||
raise ValueError(f"Email already registered: {user.email}") from e
|
||||
raise
|
||||
return user
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
"""Get user by ID from SQLite."""
|
||||
return await asyncio.to_thread(self._get_user_by_id_sync, user_id)
|
||||
|
||||
def _get_user_by_id_sync(self, user_id: str) -> User | None:
|
||||
"""Synchronous get by ID (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_user(dict(row))
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
"""Get user by email from SQLite."""
|
||||
return await asyncio.to_thread(self._get_user_by_email_sync, email)
|
||||
|
||||
def _get_user_by_email_sync(self, email: str) -> User | None:
|
||||
"""Synchronous get by email (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute("SELECT * FROM users WHERE email = ?", (email,))
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_user(dict(row))
|
||||
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update an existing user in SQLite."""
|
||||
return await asyncio.to_thread(self._update_user_sync, user)
|
||||
|
||||
def _update_user_sync(self, user: User) -> User:
|
||||
with _get_users_conn() as conn:
|
||||
conn.execute(
|
||||
"UPDATE users SET email = ?, password_hash = ?, system_role = ?, oauth_provider = ?, oauth_id = ?, needs_setup = ?, token_version = ? WHERE id = ?",
|
||||
(user.email, user.password_hash, user.system_role, user.oauth_provider, user.oauth_id, int(user.needs_setup), user.token_version, str(user.id)),
|
||||
)
|
||||
conn.commit()
|
||||
return user
|
||||
|
||||
async def count_users(self) -> int:
|
||||
"""Return total number of registered users."""
|
||||
return await asyncio.to_thread(self._count_users_sync)
|
||||
|
||||
def _count_users_sync(self) -> int:
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Get user by OAuth provider and ID from SQLite."""
|
||||
return await asyncio.to_thread(self._get_user_by_oauth_sync, provider, oauth_id)
|
||||
|
||||
def _get_user_by_oauth_sync(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Synchronous get by OAuth (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
(provider, oauth_id),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_user(dict(row))
|
||||
|
||||
@staticmethod
|
||||
def _row_to_user(row: dict[str, Any]) -> User:
|
||||
"""Convert a database row to a User model."""
|
||||
return User(
|
||||
id=UUID(row["id"]),
|
||||
email=row["email"],
|
||||
password_hash=row["password_hash"],
|
||||
system_role=row["system_role"],
|
||||
created_at=datetime.fromtimestamp(row["created_at"], tz=UTC),
|
||||
oauth_provider=row.get("oauth_provider"),
|
||||
oauth_id=row.get("oauth_id"),
|
||||
needs_setup=bool(row["needs_setup"]),
|
||||
token_version=int(row["token_version"]),
|
||||
)
|
||||
Reference in New Issue
Block a user