refactor(auth): migrate user repository to SQLAlchemy ORM

Move the users table into the shared persistence engine so auth
matches the pattern of threads_meta, runs, run_events, and feedback —
one engine, one session factory, one schema init codepath.

New files
---------
- persistence/user/__init__.py, persistence/user/model.py: UserRow
  ORM class with partial unique index on (oauth_provider, oauth_id)
- UserRow is also registered in persistence/models/__init__.py (an
  existing file) so Base.metadata.create_all() picks it up

Modified
--------
- auth/repositories/sqlite.py: rewritten on top of async SQLAlchemy
  sessions, using the same constructor pattern as the other four
  repositories
  (def __init__(self, session_factory) + self._sf = session_factory)
- auth/config.py: drop users_db_path field — storage is configured
  through config.database like every other table
- deps.py/get_local_provider: construct SQLiteUserRepository with
  the shared session factory; raise RuntimeError (fail fast) if the
  engine has not been initialised
- tests/test_auth.py: rewrite test_sqlite_round_trip_new_fields to
  use the shared engine (init_engine + close_engine in a tempdir)
- tests/test_auth_type_system.py: add per-test autouse fixture that
  spins up a scratch engine and resets deps._cached_* singletons
This commit is contained in:
greatmengqi
2026-04-08 11:49:24 +08:00
parent f0b065bef6
commit ceeccabc98
8 changed files with 254 additions and 216 deletions
+7 -5
View File
@@ -13,17 +13,19 @@ logger = logging.getLogger(__name__)
class AuthConfig(BaseModel): class AuthConfig(BaseModel):
"""JWT and auth-related configuration. Parsed once at startup.""" """JWT and auth-related configuration. Parsed once at startup.
Note: the ``users`` table now lives in the shared persistence
database managed by ``deerflow.persistence.engine``. The old
``users_db_path`` config key has been removed — user storage is
configured through ``config.database`` like every other table.
"""
jwt_secret: str = Field( jwt_secret: str = Field(
..., ...,
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.", description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
) )
token_expiry_days: int = Field(default=7, ge=1, le=30) token_expiry_days: int = Field(default=7, ge=1, le=30)
users_db_path: str | None = Field(
default=None,
description="Path to users SQLite DB. Defaults to .deer-flow/users.db",
)
oauth_github_client_id: str | None = Field(default=None) oauth_github_client_id: str | None = Field(default=None)
oauth_github_client_secret: str | None = Field(default=None) oauth_github_client_secret: str | None = Field(default=None)
+93 -173
View File
@@ -1,196 +1,116 @@
"""SQLite implementation of UserRepository.""" """SQLAlchemy-backed UserRepository implementation.
import asyncio Uses the shared async session factory from
import sqlite3 ``deerflow.persistence.engine`` — the ``users`` table lives in the
from contextlib import contextmanager same database as ``threads_meta``, ``runs``, ``run_events``, and
from datetime import UTC, datetime ``feedback``.
from pathlib import Path
from typing import Any Constructor takes the session factory directly (same pattern as the
other four repositories in ``deerflow.persistence.*``). Callers
construct this after ``init_engine_from_config()`` has run.
"""
from __future__ import annotations
from datetime import UTC
from uuid import UUID from uuid import UUID
from app.gateway.auth.config import get_auth_config from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from app.gateway.auth.models import User from app.gateway.auth.models import User
from app.gateway.auth.repositories.base import UserRepository from app.gateway.auth.repositories.base import UserRepository
from deerflow.persistence.user.model import UserRow
_resolved_db_path: Path | None = None
_table_initialized: bool = False
def _get_users_db_path() -> Path:
"""Get the users database path (resolved and cached once)."""
global _resolved_db_path
if _resolved_db_path is not None:
return _resolved_db_path
config = get_auth_config()
if config.users_db_path:
_resolved_db_path = Path(config.users_db_path)
else:
_resolved_db_path = Path(".deer-flow/users.db")
_resolved_db_path.parent.mkdir(parents=True, exist_ok=True)
return _resolved_db_path
def _get_connection() -> sqlite3.Connection:
"""Get a SQLite connection for the users database."""
db_path = _get_users_db_path()
conn = sqlite3.connect(str(db_path))
conn.row_factory = sqlite3.Row
return conn
def _init_users_table(conn: sqlite3.Connection) -> None:
"""Initialize the users table if it doesn't exist."""
conn.execute("PRAGMA journal_mode=WAL")
conn.execute(
"""
CREATE TABLE IF NOT EXISTS users (
id TEXT PRIMARY KEY,
email TEXT UNIQUE NOT NULL,
password_hash TEXT,
system_role TEXT NOT NULL DEFAULT 'user',
created_at REAL NOT NULL,
oauth_provider TEXT,
oauth_id TEXT,
needs_setup INTEGER NOT NULL DEFAULT 0,
token_version INTEGER NOT NULL DEFAULT 0
)
"""
)
# Add unique constraint for OAuth identity to prevent duplicate social logins
conn.execute(
"""
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_oauth_identity
ON users(oauth_provider, oauth_id)
WHERE oauth_provider IS NOT NULL AND oauth_id IS NOT NULL
"""
)
conn.commit()
@contextmanager
def _get_users_conn():
"""Context manager for users database connection."""
global _table_initialized
conn = _get_connection()
try:
if not _table_initialized:
_init_users_table(conn)
_table_initialized = True
yield conn
finally:
conn.close()
class SQLiteUserRepository(UserRepository): class SQLiteUserRepository(UserRepository):
"""SQLite implementation of UserRepository.""" """Async user repository backed by the shared SQLAlchemy engine."""
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
# ── Converters ────────────────────────────────────────────────────
@staticmethod
def _row_to_user(row: UserRow) -> User:
return User(
id=UUID(row.id),
email=row.email,
password_hash=row.password_hash,
system_role=row.system_role, # type: ignore[arg-type]
# SQLite loses tzinfo on read; reattach UTC so downstream
# code can compare timestamps reliably.
created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
oauth_provider=row.oauth_provider,
oauth_id=row.oauth_id,
needs_setup=row.needs_setup,
token_version=row.token_version,
)
@staticmethod
def _user_to_row(user: User) -> UserRow:
return UserRow(
id=str(user.id),
email=user.email,
password_hash=user.password_hash,
system_role=user.system_role,
created_at=user.created_at,
oauth_provider=user.oauth_provider,
oauth_id=user.oauth_id,
needs_setup=user.needs_setup,
token_version=user.token_version,
)
# ── CRUD ──────────────────────────────────────────────────────────
async def create_user(self, user: User) -> User: async def create_user(self, user: User) -> User:
"""Create a new user in SQLite.""" """Insert a new user. Raises ``ValueError`` on duplicate email."""
return await asyncio.to_thread(self._create_user_sync, user) row = self._user_to_row(user)
async with self._sf() as session:
def _create_user_sync(self, user: User) -> User: session.add(row)
"""Synchronous user creation (runs in thread pool)."""
with _get_users_conn() as conn:
try: try:
conn.execute( await session.commit()
""" except IntegrityError as exc:
INSERT INTO users (id, email, password_hash, system_role, created_at, oauth_provider, oauth_id, needs_setup, token_version) await session.rollback()
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) raise ValueError(f"Email already registered: {user.email}") from exc
""",
(
str(user.id),
user.email,
user.password_hash,
user.system_role,
datetime.now(UTC).timestamp(),
user.oauth_provider,
user.oauth_id,
int(user.needs_setup),
user.token_version,
),
)
conn.commit()
except sqlite3.IntegrityError as e:
if "UNIQUE constraint failed: users.email" in str(e):
raise ValueError(f"Email already registered: {user.email}") from e
raise
return user return user
async def get_user_by_id(self, user_id: str) -> User | None: async def get_user_by_id(self, user_id: str) -> User | None:
"""Get user by ID from SQLite.""" async with self._sf() as session:
return await asyncio.to_thread(self._get_user_by_id_sync, user_id) row = await session.get(UserRow, user_id)
return self._row_to_user(row) if row is not None else None
def _get_user_by_id_sync(self, user_id: str) -> User | None:
"""Synchronous get by ID (runs in thread pool)."""
with _get_users_conn() as conn:
cursor = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,))
row = cursor.fetchone()
if row is None:
return None
return self._row_to_user(dict(row))
async def get_user_by_email(self, email: str) -> User | None: async def get_user_by_email(self, email: str) -> User | None:
"""Get user by email from SQLite.""" stmt = select(UserRow).where(UserRow.email == email)
return await asyncio.to_thread(self._get_user_by_email_sync, email) async with self._sf() as session:
result = await session.execute(stmt)
def _get_user_by_email_sync(self, email: str) -> User | None: row = result.scalar_one_or_none()
"""Synchronous get by email (runs in thread pool).""" return self._row_to_user(row) if row is not None else None
with _get_users_conn() as conn:
cursor = conn.execute("SELECT * FROM users WHERE email = ?", (email,))
row = cursor.fetchone()
if row is None:
return None
return self._row_to_user(dict(row))
async def update_user(self, user: User) -> User: async def update_user(self, user: User) -> User:
"""Update an existing user in SQLite.""" async with self._sf() as session:
return await asyncio.to_thread(self._update_user_sync, user) row = await session.get(UserRow, str(user.id))
if row is None:
def _update_user_sync(self, user: User) -> User: return user
with _get_users_conn() as conn: row.email = user.email
conn.execute( row.password_hash = user.password_hash
"UPDATE users SET email = ?, password_hash = ?, system_role = ?, oauth_provider = ?, oauth_id = ?, needs_setup = ?, token_version = ? WHERE id = ?", row.system_role = user.system_role
(user.email, user.password_hash, user.system_role, user.oauth_provider, user.oauth_id, int(user.needs_setup), user.token_version, str(user.id)), row.oauth_provider = user.oauth_provider
) row.oauth_id = user.oauth_id
conn.commit() row.needs_setup = user.needs_setup
row.token_version = user.token_version
await session.commit()
return user return user
async def count_users(self) -> int: async def count_users(self) -> int:
"""Return total number of registered users.""" stmt = select(func.count()).select_from(UserRow)
return await asyncio.to_thread(self._count_users_sync) async with self._sf() as session:
return await session.scalar(stmt) or 0
def _count_users_sync(self) -> int:
with _get_users_conn() as conn:
cursor = conn.execute("SELECT COUNT(*) FROM users")
return cursor.fetchone()[0]
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
"""Get user by OAuth provider and ID from SQLite.""" stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
return await asyncio.to_thread(self._get_user_by_oauth_sync, provider, oauth_id) async with self._sf() as session:
result = await session.execute(stmt)
def _get_user_by_oauth_sync(self, provider: str, oauth_id: str) -> User | None: row = result.scalar_one_or_none()
"""Synchronous get by OAuth (runs in thread pool).""" return self._row_to_user(row) if row is not None else None
with _get_users_conn() as conn:
cursor = conn.execute(
"SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?",
(provider, oauth_id),
)
row = cursor.fetchone()
if row is None:
return None
return self._row_to_user(dict(row))
@staticmethod
def _row_to_user(row: dict[str, Any]) -> User:
"""Convert a database row to a User model."""
return User(
id=UUID(row["id"]),
email=row["email"],
password_hash=row["password_hash"],
system_role=row["system_role"],
created_at=datetime.fromtimestamp(row["created_at"], tz=UTC),
oauth_provider=row.get("oauth_provider"),
oauth_id=row.get("oauth_id"),
needs_setup=bool(row["needs_setup"]),
token_version=int(row["token_version"]),
)
+10 -2
View File
@@ -142,12 +142,20 @@ _cached_repo: SQLiteUserRepository | None = None
def get_local_provider() -> LocalAuthProvider: def get_local_provider() -> LocalAuthProvider:
"""Get or create the cached LocalAuthProvider singleton.""" """Get or create the cached LocalAuthProvider singleton.
Must be called after ``init_engine_from_config()`` — the shared
session factory is required to construct the user repository.
"""
global _cached_local_provider, _cached_repo global _cached_local_provider, _cached_repo
if _cached_repo is None: if _cached_repo is None:
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.engine import get_session_factory
_cached_repo = SQLiteUserRepository() sf = get_session_factory()
if sf is None:
raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
_cached_repo = SQLiteUserRepository(sf)
if _cached_local_provider is None: if _cached_local_provider is None:
from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.local_provider import LocalAuthProvider
@@ -7,6 +7,7 @@ The actual ORM classes have moved to entity-specific subpackages:
- ``deerflow.persistence.thread_meta`` - ``deerflow.persistence.thread_meta``
- ``deerflow.persistence.run`` - ``deerflow.persistence.run``
- ``deerflow.persistence.feedback`` - ``deerflow.persistence.feedback``
- ``deerflow.persistence.user``
``RunEventRow`` remains in ``deerflow.persistence.models.run_event`` because ``RunEventRow`` remains in ``deerflow.persistence.models.run_event`` because
its storage implementation lives in ``deerflow.runtime.events.store.db`` and its storage implementation lives in ``deerflow.runtime.events.store.db`` and
@@ -17,5 +18,6 @@ from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.persistence.models.run_event import RunEventRow from deerflow.persistence.models.run_event import RunEventRow
from deerflow.persistence.run.model import RunRow from deerflow.persistence.run.model import RunRow
from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.persistence.user.model import UserRow
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow"] __all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow", "UserRow"]
@@ -0,0 +1,12 @@
"""User storage subpackage.
Holds the ORM model for the ``users`` table. The concrete repository
implementation (``SQLiteUserRepository``) lives in the app layer
(``app.gateway.auth.repositories.sqlite``) because it converts
between the ORM row and the auth module's pydantic ``User`` class.
This keeps the harness package free of any dependency on app code.
"""
from deerflow.persistence.user.model import UserRow
__all__ = ["UserRow"]
@@ -0,0 +1,59 @@
"""ORM model for the users table.
Lives in the harness persistence package so it is picked up by
``Base.metadata.create_all()`` alongside ``threads_meta``, ``runs``,
``run_events``, and ``feedback``. Using the shared engine means:
- One SQLite/Postgres database, one connection pool
- One schema initialisation codepath
- Consistent async sessions across auth and persistence reads
"""
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import Boolean, DateTime, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column
from deerflow.persistence.base import Base
class UserRow(Base):
    """ORM row for the ``users`` table.

    Columns cover account identity (``id``, ``email``), the optional
    password hash, the system role, the creation timestamp, optional
    OAuth linkage, and the auth lifecycle flags (``needs_setup``,
    ``token_version``).
    """

    __tablename__ = "users"

    # UUIDs are stored as 36-char strings for cross-backend portability.
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
    # Nullable: a row may exist without a password hash.
    password_hash: Mapped[str | None] = mapped_column(String(128), nullable=True)

    # "admin" | "user" — kept as plain string to avoid ALTER TABLE pain
    # when new roles are introduced.
    system_role: Mapped[str] = mapped_column(String(16), nullable=False, default="user")

    # Timezone-aware column; defaults to the current UTC time when the
    # caller does not supply a value.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        default=lambda: datetime.now(UTC),
    )

    # OAuth linkage (optional). A partial unique index enforces one
    # account per (provider, oauth_id) pair, leaving NULL/NULL rows
    # unconstrained so plain password accounts can coexist.
    oauth_provider: Mapped[str | None] = mapped_column(String(32), nullable=True)
    oauth_id: Mapped[str | None] = mapped_column(String(128), nullable=True)

    # Auth lifecycle flags
    needs_setup: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    token_version: Mapped[int] = mapped_column(nullable=False, default=0)

    __table_args__ = (
        Index(
            "idx_users_oauth_identity",
            "oauth_provider",
            "oauth_id",
            unique=True,
            # The WHERE predicate of a partial index is a dialect-specific
            # option, so it has to be spelled out once per backend. Without
            # the PostgreSQL variant the index would be emitted there as a
            # plain (non-partial) unique index, diverging from SQLite.
            sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
            postgresql_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
        ),
    )
+44 -35
View File
@@ -262,47 +262,56 @@ def test_user_model_needs_setup_true():
def test_sqlite_round_trip_new_fields(): def test_sqlite_round_trip_new_fields():
"""needs_setup and token_version survive create → read round-trip.""" """needs_setup and token_version survive create → read round-trip.
Uses the shared persistence engine (same one threads_meta, runs,
run_events, and feedback use). The old separate .deer-flow/users.db
file is gone.
"""
import asyncio import asyncio
import os
import tempfile import tempfile
from pathlib import Path
from app.gateway.auth.repositories import sqlite as sqlite_mod from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
with tempfile.TemporaryDirectory() as tmpdir: async def _run() -> None:
db_path = os.path.join(tmpdir, "test_users.db") from deerflow.persistence.engine import (
old_path = sqlite_mod._resolved_db_path close_engine,
old_init = sqlite_mod._table_initialized get_session_factory,
sqlite_mod._resolved_db_path = Path(db_path) init_engine,
sqlite_mod._table_initialized = False )
try:
repo = sqlite_mod.SQLiteUserRepository()
user = User(
email="setup@test.com",
password_hash="fakehash",
system_role="admin",
needs_setup=True,
token_version=3,
)
created = asyncio.run(repo.create_user(user))
assert created.needs_setup is True
assert created.token_version == 3
fetched = asyncio.run(repo.get_user_by_email("setup@test.com")) with tempfile.TemporaryDirectory() as tmpdir:
assert fetched is not None url = f"sqlite+aiosqlite:///{tmpdir}/scratch.db"
assert fetched.needs_setup is True await init_engine("sqlite", url=url, sqlite_dir=tmpdir)
assert fetched.token_version == 3 try:
repo = SQLiteUserRepository(get_session_factory())
user = User(
email="setup@test.com",
password_hash="fakehash",
system_role="admin",
needs_setup=True,
token_version=3,
)
created = await repo.create_user(user)
assert created.needs_setup is True
assert created.token_version == 3
fetched.needs_setup = False fetched = await repo.get_user_by_email("setup@test.com")
fetched.token_version = 4 assert fetched is not None
asyncio.run(repo.update_user(fetched)) assert fetched.needs_setup is True
refetched = asyncio.run(repo.get_user_by_id(str(fetched.id))) assert fetched.token_version == 3
assert refetched.needs_setup is False
assert refetched.token_version == 4 fetched.needs_setup = False
finally: fetched.token_version = 4
sqlite_mod._resolved_db_path = old_path await repo.update_user(fetched)
sqlite_mod._table_initialized = old_init refetched = await repo.get_user_by_id(str(fetched.id))
assert refetched is not None
assert refetched.needs_setup is False
assert refetched.token_version == 4
finally:
await close_engine()
asyncio.run(_run())
# ── Token Versioning ─────────────────────────────────────────────────────── # ── Token Versioning ───────────────────────────────────────────────────────
+26
View File
@@ -32,6 +32,32 @@ from app.gateway.csrf_middleware import (
_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32" _TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
@pytest.fixture(autouse=True)
def _persistence_engine(tmp_path):
    """Give every test its own SQLite engine and clean provider singletons.

    The auth tests call real HTTP handlers that go through
    ``SQLiteUserRepository`` → ``get_session_factory``. A scratch
    database is created under ``tmp_path`` before the test body runs, and
    the ``deps._cached_*`` singletons are cleared both before and after
    it so the cached provider does not hold a dangling reference to the
    previous test's engine. The engine is closed on teardown.
    """
    import asyncio

    from app.gateway import deps
    from deerflow.persistence.engine import close_engine, init_engine

    def _drop_cached_singletons() -> None:
        # Clear the module-level caches so the next provider lookup is
        # built against the currently initialised engine.
        deps._cached_local_provider = None
        deps._cached_repo = None

    db_url = f"sqlite+aiosqlite:///{tmp_path}/auth_types.db"
    asyncio.run(init_engine("sqlite", url=db_url, sqlite_dir=str(tmp_path)))
    _drop_cached_singletons()
    try:
        yield
    finally:
        # Reset the caches first, then dispose of the engine they pointed at.
        _drop_cached_singletons()
        asyncio.run(close_engine())
def _setup_config(): def _setup_config():
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET)) set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))