ceeccabc98
Move the users table into the shared persistence engine so auth matches the pattern of threads_meta, runs, run_events, and feedback — one engine, one session factory, one schema init codepath.

New files
---------
- persistence/user/__init__.py, persistence/user/model.py: UserRow ORM class with a partial unique index on (oauth_provider, oauth_id)
- Registered in persistence/models/__init__.py so Base.metadata.create_all() picks it up

Modified
--------
- auth/repositories/sqlite.py: rewritten as async SQLAlchemy, with a constructor pattern identical to the other four repositories (def __init__(self, session_factory) + self._sf = session_factory)
- auth/config.py: drop the users_db_path field — storage is configured through config.database like every other table
- deps.py/get_local_provider: construct SQLiteUserRepository with the shared session factory; fail fast if the engine is not initialised
- tests/test_auth.py: rewrite test_sqlite_round_trip_new_fields to use the shared engine (init_engine + close_engine in a tempdir)
- tests/test_auth_type_system.py: add a per-test autouse fixture that spins up a scratch engine and resets the deps._cached_* singletons
117 lines
4.7 KiB
Python
117 lines
4.7 KiB
Python
"""SQLAlchemy-backed UserRepository implementation.
|
|
|
|
Uses the shared async session factory from
|
|
``deerflow.persistence.engine`` — the ``users`` table lives in the
|
|
same database as ``threads_meta``, ``runs``, ``run_events``, and
|
|
``feedback``.
|
|
|
|
Constructor takes the session factory directly (same pattern as the
|
|
other four repositories in ``deerflow.persistence.*``). Callers
|
|
construct this after ``init_engine_from_config()`` has run.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import UTC
|
|
from uuid import UUID
|
|
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|
|
|
from app.gateway.auth.models import User
|
|
from app.gateway.auth.repositories.base import UserRepository
|
|
from deerflow.persistence.user.model import UserRow
|
|
|
|
|
|
class SQLiteUserRepository(UserRepository):
    """Async user repository backed by the shared SQLAlchemy engine.

    Each method opens its own session from the injected factory inside an
    ``async with`` block, so no session outlives a single call.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
        """Keep a reference to the shared async session factory.

        Args:
            session_factory: Factory from the shared persistence engine.
                Callers construct the repository after the engine has
                been initialised.
        """
        self._sf = session_factory

    # ── Converters ────────────────────────────────────────────────────

    @staticmethod
    def _row_to_user(row: UserRow) -> User:
        """Convert a persisted ``UserRow`` into a domain ``User``.

        A naive ``created_at`` is interpreted as UTC. That is consistent
        with ``_user_to_row``, which normalises aware timestamps to UTC
        before they are written.
        """
        created_at = row.created_at
        # SQLite loses tzinfo on read; reattach UTC so downstream
        # code can compare timestamps reliably.
        if created_at.tzinfo is None:
            created_at = created_at.replace(tzinfo=UTC)
        return User(
            id=UUID(row.id),
            email=row.email,
            password_hash=row.password_hash,
            system_role=row.system_role,  # type: ignore[arg-type]
            created_at=created_at,
            oauth_provider=row.oauth_provider,
            oauth_id=row.oauth_id,
            needs_setup=row.needs_setup,
            token_version=row.token_version,
        )

    @staticmethod
    def _user_to_row(user: User) -> UserRow:
        """Build a new, unattached ``UserRow`` from a domain ``User``."""
        created_at = user.created_at
        # SQLite's DateTime type stores only the wall-clock fields and drops
        # the UTC offset without converting. An aware non-UTC value would
        # therefore be read back (and tagged UTC by ``_row_to_user``) at the
        # wrong instant. Normalise aware values to UTC before storing. Naive
        # values pass through unchanged and are assumed to already be UTC.
        if created_at is not None and created_at.tzinfo is not None:
            created_at = created_at.astimezone(UTC)
        return UserRow(
            id=str(user.id),
            email=user.email,
            password_hash=user.password_hash,
            system_role=user.system_role,
            created_at=created_at,
            oauth_provider=user.oauth_provider,
            oauth_id=user.oauth_id,
            needs_setup=user.needs_setup,
            token_version=user.token_version,
        )

    # ── CRUD ──────────────────────────────────────────────────────────

    async def create_user(self, user: User) -> User:
        """Insert a new user. Raises ``ValueError`` on duplicate email.

        Args:
            user: Fully populated domain user to persist.

        Returns:
            The same ``user`` object that was passed in.

        Raises:
            ValueError: If the commit fails with an ``IntegrityError``.
                NOTE(review): every integrity violation (e.g. a duplicate
                id, or presumably the (oauth_provider, oauth_id) unique
                index) is reported with the duplicate-email message.
        """
        row = self._user_to_row(user)
        async with self._sf() as session:
            session.add(row)
            try:
                await session.commit()
            except IntegrityError as exc:
                await session.rollback()
                raise ValueError(f"Email already registered: {user.email}") from exc
        return user

    async def get_user_by_id(self, user_id: str) -> User | None:
        """Return the user whose primary key is ``user_id``, or ``None``.

        Ids are stored as ``str(user.id)``. ``user_id`` is looked up
        exactly as given, without any normalisation.
        """
        async with self._sf() as session:
            row = await session.get(UserRow, user_id)
            return self._row_to_user(row) if row is not None else None

    async def get_user_by_email(self, email: str) -> User | None:
        """Return the user registered with ``email``, or ``None``.

        ``scalar_one_or_none`` raises if several rows match. Emails are
        presumably unique in the schema; confirm against ``UserRow``.
        """
        stmt = select(UserRow).where(UserRow.email == email)
        async with self._sf() as session:
            result = await session.execute(stmt)
            row = result.scalar_one_or_none()
            return self._row_to_user(row) if row is not None else None

    async def update_user(self, user: User) -> User:
        """Persist the mutable fields of ``user`` and return it.

        ``id`` and ``created_at`` are never rewritten. If no row with
        ``user.id`` exists, this is a silent no-op and ``user`` is
        returned unchanged.
        """
        async with self._sf() as session:
            row = await session.get(UserRow, str(user.id))
            if row is None:
                return user
            row.email = user.email
            row.password_hash = user.password_hash
            row.system_role = user.system_role
            row.oauth_provider = user.oauth_provider
            row.oauth_id = user.oauth_id
            row.needs_setup = user.needs_setup
            row.token_version = user.token_version
            await session.commit()
        return user

    async def count_users(self) -> int:
        """Return the number of stored users (``0`` when the query yields nothing)."""
        stmt = select(func.count()).select_from(UserRow)
        async with self._sf() as session:
            return await session.scalar(stmt) or 0

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Return the user linked to ``(provider, oauth_id)``, or ``None``.

        ``scalar_one_or_none`` raises if several rows match. The pair is
        presumably kept unique by a partial index on ``UserRow``; verify.
        """
        stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
        async with self._sf() as session:
            result = await session.execute(stmt)
            row = result.scalar_one_or_none()
            return self._row_to_user(row) if row is not None else None
|