refactor(auth): migrate user repository to SQLAlchemy ORM

Move the users table into the shared persistence engine so auth
matches the pattern of threads_meta, runs, run_events, and feedback —
one engine, one session factory, one schema init codepath.

New files
---------
- persistence/user/__init__.py, persistence/user/model.py: UserRow
  ORM class with partial unique index on (oauth_provider, oauth_id)
- Registered in persistence/models/__init__.py so
  Base.metadata.create_all() picks it up

Modified
--------
- auth/repositories/sqlite.py: rewritten as async SQLAlchemy,
  identical constructor pattern to the other four repositories
  (def __init__(self, session_factory) + self._sf = session_factory)
- auth/config.py: drop users_db_path field — storage is configured
  through config.database like every other table
- deps.py/get_local_provider: construct SQLiteUserRepository with
  the shared session factory, fail fast if engine is not initialised
- tests/test_auth.py: rewrite test_sqlite_round_trip_new_fields to
  use the shared engine (init_engine + close_engine in a tempdir)
- tests/test_auth_type_system.py: add per-test autouse fixture that
  spins up a scratch engine and resets deps._cached_* singletons
This commit is contained in:
greatmengqi
2026-04-08 11:49:24 +08:00
parent f0b065bef6
commit ceeccabc98
8 changed files with 254 additions and 216 deletions
+7 -5
View File
@@ -13,17 +13,19 @@ logger = logging.getLogger(__name__)
class AuthConfig(BaseModel):
"""JWT and auth-related configuration. Parsed once at startup."""
"""JWT and auth-related configuration. Parsed once at startup.
Note: the ``users`` table now lives in the shared persistence
database managed by ``deerflow.persistence.engine``. The old
``users_db_path`` config key has been removed — user storage is
configured through ``config.database`` like every other table.
"""
jwt_secret: str = Field(
...,
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
)
token_expiry_days: int = Field(default=7, ge=1, le=30)
users_db_path: str | None = Field(
default=None,
description="Path to users SQLite DB. Defaults to .deer-flow/users.db",
)
oauth_github_client_id: str | None = Field(default=None)
oauth_github_client_secret: str | None = Field(default=None)
+93 -173
View File
@@ -1,196 +1,116 @@
"""SQLite implementation of UserRepository."""
"""SQLAlchemy-backed UserRepository implementation.
import asyncio
import sqlite3
from contextlib import contextmanager
from datetime import UTC, datetime
from pathlib import Path
from typing import Any
Uses the shared async session factory from
``deerflow.persistence.engine`` — the ``users`` table lives in the
same database as ``threads_meta``, ``runs``, ``run_events``, and
``feedback``.
Constructor takes the session factory directly (same pattern as the
other four repositories in ``deerflow.persistence.*``). Callers
construct this after ``init_engine_from_config()`` has run.
"""
from __future__ import annotations
from datetime import UTC
from uuid import UUID
from app.gateway.auth.config import get_auth_config
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.gateway.auth.models import User
from app.gateway.auth.repositories.base import UserRepository
_resolved_db_path: Path | None = None
_table_initialized: bool = False
def _get_users_db_path() -> Path:
"""Get the users database path (resolved and cached once)."""
global _resolved_db_path
if _resolved_db_path is not None:
return _resolved_db_path
config = get_auth_config()
if config.users_db_path:
_resolved_db_path = Path(config.users_db_path)
else:
_resolved_db_path = Path(".deer-flow/users.db")
_resolved_db_path.parent.mkdir(parents=True, exist_ok=True)
return _resolved_db_path
def _get_connection() -> sqlite3.Connection:
"""Get a SQLite connection for the users database."""
db_path = _get_users_db_path()
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
return conn
def _init_users_table(conn: sqlite3.Connection) -> None:
"""Initialize the users table if it doesn't exist."""
conn.execute("PRAGMA journal_mode=WAL")
conn.execute(
"""
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
password_hash TEXT,
system_role TEXT NOT NULL DEFAULT 'user',
created_at REAL NOT NULL,
oauth_provider TEXT,
oauth_id TEXT,
needs_setup INTEGER NOT NULL DEFAULT 0,
token_version INTEGER NOT NULL DEFAULT 0
)
"""
)
# Add unique constraint for OAuth identity to prevent duplicate social logins
conn.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_oauth_identity
ON users(oauth_provider, oauth_id)
WHERE oauth_provider IS NOT NULL AND oauth_id IS NOT NULL
"""
)
conn.commit()
@contextmanager
def _get_users_conn():
"""Context manager for users database connection."""
global _table_initialized
conn = _get_connection()
try:
if not _table_initialized:
_init_users_table(conn)
_table_initialized = True
yield conn
finally:
conn.close()
from deerflow.persistence.user.model import UserRow
class SQLiteUserRepository(UserRepository):
"""SQLite implementation of UserRepository."""
"""Async user repository backed by the shared SQLAlchemy engine."""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
# ── Converters ────────────────────────────────────────────────────
@staticmethod
def _row_to_user(row: UserRow) -> User:
return User(
id=UUID(row.id),
email=row.email,
password_hash=row.password_hash,
system_role=row.system_role, # type: ignore[arg-type]
# SQLite loses tzinfo on read; reattach UTC so downstream
# code can compare timestamps reliably.
created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
oauth_provider=row.oauth_provider,
oauth_id=row.oauth_id,
needs_setup=row.needs_setup,
token_version=row.token_version,
)
@staticmethod
def _user_to_row(user: User) -> UserRow:
return UserRow(
id=str(user.id),
email=user.email,
password_hash=user.password_hash,
system_role=user.system_role,
created_at=user.created_at,
oauth_provider=user.oauth_provider,
oauth_id=user.oauth_id,
needs_setup=user.needs_setup,
token_version=user.token_version,
)
# ── CRUD ──────────────────────────────────────────────────────────
async def create_user(self, user: User) -> User:
"""Create a new user in SQLite."""
return await asyncio.to_thread(self._create_user_sync, user)
def _create_user_sync(self, user: User) -> User:
"""Synchronous user creation (runs in thread pool)."""
with _get_users_conn() as conn:
"""Insert a new user. Raises ``ValueError`` on duplicate email."""
row = self._user_to_row(user)
async with self._sf() as session:
session.add(row)
try:
conn.execute(
"""
INSERT INTO users (id, email, password_hash, system_role, created_at, oauth_provider, oauth_id, needs_setup, token_version)
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
""",
(
str(user.id),
user.email,
user.password_hash,
user.system_role,
datetime.now(UTC).timestamp(),
user.oauth_provider,
user.oauth_id,
int(user.needs_setup),
user.token_version,
),
)
conn.commit()
except sqlite3.IntegrityError as e:
if "UNIQUE constraint failed: users.email" in str(e):
raise ValueError(f"Email already registered: {user.email}") from e
raise
await session.commit()
except IntegrityError as exc:
await session.rollback()
raise ValueError(f"Email already registered: {user.email}") from exc
return user
async def get_user_by_id(self, user_id: str) -> User | None:
"""Get user by ID from SQLite."""
return await asyncio.to_thread(self._get_user_by_id_sync, user_id)
def _get_user_by_id_sync(self, user_id: str) -> User | None:
"""Synchronous get by ID (runs in thread pool)."""
with _get_users_conn() as conn:
cursor = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,))
row = cursor.fetchone()
if row is None:
return None
return self._row_to_user(dict(row))
async with self._sf() as session:
row = await session.get(UserRow, user_id)
return self._row_to_user(row) if row is not None else None
async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email from SQLite."""
return await asyncio.to_thread(self._get_user_by_email_sync, email)
def _get_user_by_email_sync(self, email: str) -> User | None:
"""Synchronous get by email (runs in thread pool)."""
with _get_users_conn() as conn:
cursor = conn.execute("SELECT * FROM users WHERE email = ?", (email,))
row = cursor.fetchone()
if row is None:
return None
return self._row_to_user(dict(row))
stmt = select(UserRow).where(UserRow.email == email)
async with self._sf() as session:
result = await session.execute(stmt)
row = result.scalar_one_or_none()
return self._row_to_user(row) if row is not None else None
async def update_user(self, user: User) -> User:
"""Update an existing user in SQLite."""
return await asyncio.to_thread(self._update_user_sync, user)
def _update_user_sync(self, user: User) -> User:
with _get_users_conn() as conn:
conn.execute(
"UPDATE users SET email = ?, password_hash = ?, system_role = ?, oauth_provider = ?, oauth_id = ?, needs_setup = ?, token_version = ? WHERE id = ?",
(user.email, user.password_hash, user.system_role, user.oauth_provider, user.oauth_id, int(user.needs_setup), user.token_version, str(user.id)),
)
conn.commit()
async with self._sf() as session:
row = await session.get(UserRow, str(user.id))
if row is None:
return user
row.email = user.email
row.password_hash = user.password_hash
row.system_role = user.system_role
row.oauth_provider = user.oauth_provider
row.oauth_id = user.oauth_id
row.needs_setup = user.needs_setup
row.token_version = user.token_version
await session.commit()
return user
async def count_users(self) -> int:
"""Return total number of registered users."""
return await asyncio.to_thread(self._count_users_sync)
def _count_users_sync(self) -> int:
with _get_users_conn() as conn:
cursor = conn.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0]
stmt = select(func.count()).select_from(UserRow)
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID from SQLite."""
return await asyncio.to_thread(self._get_user_by_oauth_sync, provider, oauth_id)
def _get_user_by_oauth_sync(self, provider: str, oauth_id: str) -> User | None:
"""Synchronous get by OAuth (runs in thread pool)."""
with _get_users_conn() as conn:
cursor = conn.execute(
"SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?",
(provider, oauth_id),
)
row = cursor.fetchone()
if row is None:
return None
return self._row_to_user(dict(row))
@staticmethod
def _row_to_user(row: dict[str, Any]) -> User:
"""Convert a database row to a User model."""
return User(
id=UUID(row["id"]),
email=row["email"],
password_hash=row["password_hash"],
system_role=row["system_role"],
created_at=datetime.fromtimestamp(row["created_at"], tz=UTC),
oauth_provider=row.get("oauth_provider"),
oauth_id=row.get("oauth_id"),
needs_setup=bool(row["needs_setup"]),
token_version=int(row["token_version"]),
)
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
async with self._sf() as session:
result = await session.execute(stmt)
row = result.scalar_one_or_none()
return self._row_to_user(row) if row is not None else None
+10 -2
View File
@@ -142,12 +142,20 @@ _cached_repo: SQLiteUserRepository | None = None
def get_local_provider() -> LocalAuthProvider:
"""Get or create the cached LocalAuthProvider singleton."""
"""Get or create the cached LocalAuthProvider singleton.
Must be called after ``init_engine_from_config()`` — the shared
session factory is required to construct the user repository.
"""
global _cached_local_provider, _cached_repo
if _cached_repo is None:
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.engine import get_session_factory
_cached_repo = SQLiteUserRepository()
sf = get_session_factory()
if sf is None:
raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
_cached_repo = SQLiteUserRepository(sf)
if _cached_local_provider is None:
from app.gateway.auth.local_provider import LocalAuthProvider