mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 07:26:50 +00:00
03c3b18565
Port RFC-001 authentication core from PR #1728: - JWT token handling (create_access_token, decode_token, TokenPayload) - Password hashing (bcrypt) with verify_password - SQLite UserRepository with base interface - Provider Factory pattern (LocalAuthProvider) - CLI reset_admin tool - Auth-specific errors (AuthErrorCode, TokenError, AuthErrorResponse) Deps: - bcrypt>=4.0.0 - pyjwt>=2.9.0 - email-validator>=2.0.0 - backend/uv.toml pins public PyPI index Tests: 12 pure unit tests (test_auth_config.py, test_auth_errors.py). Scope note: authz.py, test_auth.py, and test_auth_type_system.py are deferred to commit 2 because they depend on middleware and deps wiring that is not yet in place. Commit 1 stays "pure new files only" as the spec mandates.
197 lines
7.2 KiB
Python
197 lines
7.2 KiB
Python
"""SQLite implementation of UserRepository."""
|
|
|
|
import asyncio
|
|
import sqlite3
|
|
from contextlib import contextmanager
|
|
from datetime import UTC, datetime
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from uuid import UUID
|
|
|
|
from app.gateway.auth.config import get_auth_config
|
|
from app.gateway.auth.models import User
|
|
from app.gateway.auth.repositories.base import UserRepository
|
|
|
|
_resolved_db_path: Path | None = None
|
|
_table_initialized: bool = False
|
|
|
|
|
|
def _get_users_db_path() -> Path:
    """Return the filesystem path of the users database, resolving it once.

    Uses ``get_auth_config().users_db_path`` when it is set (truthy). Otherwise
    it falls back to the relative path ``.deer-flow/users.db``. The parent
    directory is created if it does not exist.

    The resolved path is stored in the module-level ``_resolved_db_path``. Once
    it is set, later calls return it without consulting the config again.
    """
    global _resolved_db_path

    if _resolved_db_path is None:
        configured = get_auth_config().users_db_path
        _resolved_db_path = Path(configured) if configured else Path(".deer-flow/users.db")
        # The path is cached before the directory is created, so a failing
        # mkdir still leaves the path memoised.
        _resolved_db_path.parent.mkdir(parents=True, exist_ok=True)

    return _resolved_db_path
|
|
|
|
|
|
def _get_connection() -> sqlite3.Connection:
    """Open a new SQLite connection to the users database.

    The connection's ``row_factory`` is ``sqlite3.Row``, so fetched rows can be
    indexed by column name. The connection is returned open; the caller is
    responsible for closing it.
    """
    connection = sqlite3.connect(str(_get_users_db_path()))
    connection.row_factory = sqlite3.Row
    return connection
|
|
|
|
|
|
def _init_users_table(conn: sqlite3.Connection) -> None:
    """Ensure the ``users`` schema exists on *conn*, then commit.

    The following statements run in order:

    1. Switch the database to WAL journaling.
    2. Create the ``users`` table.
    3. Create a partial unique index on ``(oauth_provider, oauth_id)``.

    Each statement is written to be repeatable (``IF NOT EXISTS`` / a PRAGMA),
    so calling this more than once is harmless.
    """
    schema_statements = (
        "PRAGMA journal_mode=WAL",
        """
        CREATE TABLE IF NOT EXISTS users (
            id TEXT PRIMARY KEY,
            email TEXT UNIQUE NOT NULL,
            password_hash TEXT,
            system_role TEXT NOT NULL DEFAULT 'user',
            created_at REAL NOT NULL,
            oauth_provider TEXT,
            oauth_id TEXT,
            needs_setup INTEGER NOT NULL DEFAULT 0,
            token_version INTEGER NOT NULL DEFAULT 0
        )
        """,
        # A given external (provider, id) pair may map to at most one row.
        # Rows with a NULL provider or id are excluded from the constraint, so
        # local accounts are unaffected. This prevents duplicate social logins.
        """
        CREATE UNIQUE INDEX IF NOT EXISTS idx_users_oauth_identity
        ON users(oauth_provider, oauth_id)
        WHERE oauth_provider IS NOT NULL AND oauth_id IS NOT NULL
        """,
    )
    for statement in schema_statements:
        conn.execute(statement)
    conn.commit()
|
|
|
|
|
|
@contextmanager
def _get_users_conn():
    """Provide a users-database connection for the duration of a ``with`` block.

    While ``_table_initialized`` is False, the schema setup in
    ``_init_users_table`` runs on the new connection before it is yielded, and
    the flag is then set. The connection is always closed on exit, including
    when schema setup or the caller's block raises. Nothing is committed here;
    callers commit their own writes.
    """
    global _table_initialized

    connection = _get_connection()
    try:
        schema_pending = not _table_initialized
        if schema_pending:
            _init_users_table(connection)
            _table_initialized = True
        yield connection
    finally:
        connection.close()
|
|
|
|
|
|
class SQLiteUserRepository(UserRepository):
    """SQLite implementation of UserRepository.

    Each public coroutine delegates to a blocking ``_*_sync`` helper through
    ``asyncio.to_thread``, so the ``sqlite3`` calls run off the event loop.
    Every helper opens its own connection via ``_get_users_conn`` and closes it
    before returning.
    """

    async def create_user(self, user: User) -> User:
        """Insert *user* into the users table and return it.

        Raises:
            ValueError: if ``user.email`` is already registered.
            sqlite3.IntegrityError: for any other constraint violation, such as
                a duplicate id or a duplicate (oauth_provider, oauth_id) pair.
        """
        return await asyncio.to_thread(self._create_user_sync, user)

    def _create_user_sync(self, user: User) -> User:
        """Blocking body of ``create_user`` (runs in a worker thread)."""
        with _get_users_conn() as conn:
            try:
                conn.execute(
                    """
                    INSERT INTO users (id, email, password_hash, system_role, created_at, oauth_provider, oauth_id, needs_setup, token_version)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
                    """,
                    (
                        str(user.id),
                        user.email,
                        user.password_hash,
                        user.system_role,
                        # NOTE(review): the stored creation time is taken here
                        # rather than from user.created_at. The returned object
                        # may therefore not match a later read exactly — confirm
                        # this is intended.
                        datetime.now(UTC).timestamp(),
                        user.oauth_provider,
                        user.oauth_id,
                        int(user.needs_setup),
                        user.token_version,
                    ),
                )
                conn.commit()
            except sqlite3.IntegrityError as e:
                # Only the email uniqueness violation is translated; anything
                # else (e.g. the OAuth identity index) propagates unchanged.
                if "UNIQUE constraint failed: users.email" in str(e):
                    raise ValueError(f"Email already registered: {user.email}") from e
                raise
        return user

    async def get_user_by_id(self, user_id: str | UUID) -> User | None:
        """Return the user with id *user_id*, or ``None`` if there is no such row.

        *user_id* may be the string form or a ``uuid.UUID``. Ids are stored as
        ``str(user.id)``, so the value is normalised with ``str()`` before the
        lookup.
        """
        return await asyncio.to_thread(self._get_user_by_id_sync, user_id)

    def _get_user_by_id_sync(self, user_id: str | UUID) -> User | None:
        """Blocking body of ``get_user_by_id`` (runs in a worker thread)."""
        with _get_users_conn() as conn:
            # sqlite3 has no default adapter for uuid.UUID; stringify so both
            # accepted input types match the stored TEXT id.
            cursor = conn.execute("SELECT * FROM users WHERE id = ?", (str(user_id),))
            row = cursor.fetchone()
            if row is None:
                return None
            return self._row_to_user(dict(row))

    async def get_user_by_email(self, email: str) -> User | None:
        """Return the user whose email equals *email* exactly, or ``None``."""
        return await asyncio.to_thread(self._get_user_by_email_sync, email)

    def _get_user_by_email_sync(self, email: str) -> User | None:
        """Blocking body of ``get_user_by_email`` (runs in a worker thread)."""
        with _get_users_conn() as conn:
            cursor = conn.execute("SELECT * FROM users WHERE email = ?", (email,))
            row = cursor.fetchone()
            if row is None:
                return None
            return self._row_to_user(dict(row))

    async def update_user(self, user: User) -> User:
        """Overwrite the stored row identified by ``user.id`` and return *user*.

        Every column except ``id`` and ``created_at`` is rewritten from *user*.
        """
        return await asyncio.to_thread(self._update_user_sync, user)

    def _update_user_sync(self, user: User) -> User:
        """Blocking body of ``update_user`` (runs in a worker thread).

        NOTE(review): the UPDATE's row count is not checked, so an id with no
        matching row is silently ignored. A duplicate email surfaces as a raw
        sqlite3.IntegrityError here rather than the ValueError raised by
        create_user.
        """
        with _get_users_conn() as conn:
            conn.execute(
                "UPDATE users SET email = ?, password_hash = ?, system_role = ?, oauth_provider = ?, oauth_id = ?, needs_setup = ?, token_version = ? WHERE id = ?",
                (user.email, user.password_hash, user.system_role, user.oauth_provider, user.oauth_id, int(user.needs_setup), user.token_version, str(user.id)),
            )
            conn.commit()
        return user

    async def count_users(self) -> int:
        """Return total number of registered users."""
        return await asyncio.to_thread(self._count_users_sync)

    def _count_users_sync(self) -> int:
        """Blocking body of ``count_users``: ``SELECT COUNT(*)`` over all rows."""
        with _get_users_conn() as conn:
            cursor = conn.execute("SELECT COUNT(*) FROM users")
            return cursor.fetchone()[0]

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Return the user linked to (*provider*, *oauth_id*), or ``None``."""
        return await asyncio.to_thread(self._get_user_by_oauth_sync, provider, oauth_id)

    def _get_user_by_oauth_sync(self, provider: str, oauth_id: str) -> User | None:
        """Blocking body of ``get_user_by_oauth`` (runs in a worker thread)."""
        with _get_users_conn() as conn:
            cursor = conn.execute(
                "SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?",
                (provider, oauth_id),
            )
            row = cursor.fetchone()
            if row is None:
                return None
            return self._row_to_user(dict(row))

    @staticmethod
    def _row_to_user(row: dict[str, Any]) -> User:
        """Build a ``User`` from a ``users`` row converted to a dict.

        Conversions applied to the stored values:

        - ``id`` is parsed back into a ``UUID``.
        - ``created_at`` (a UNIX timestamp) becomes a UTC-aware datetime.
        - ``needs_setup`` (stored as 0/1) becomes a bool.
        """
        return User(
            id=UUID(row["id"]),
            email=row["email"],
            password_hash=row["password_hash"],
            system_role=row["system_role"],
            created_at=datetime.fromtimestamp(row["created_at"], tz=UTC),
            oauth_provider=row.get("oauth_provider"),
            oauth_id=row.get("oauth_id"),
            needs_setup=bool(row["needs_setup"]),
            token_version=int(row["token_version"]),
        )
|