mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
refactor(auth): migrate user repository to SQLAlchemy ORM
Move the users table into the shared persistence engine so auth matches the pattern of threads_meta, runs, run_events, and feedback — one engine, one session factory, one schema init codepath. New files --------- - persistence/user/__init__.py, persistence/user/model.py: UserRow ORM class with partial unique index on (oauth_provider, oauth_id) - Registered in persistence/models/__init__.py so Base.metadata.create_all() picks it up Modified -------- - auth/repositories/sqlite.py: rewritten as async SQLAlchemy, identical constructor pattern to the other four repositories (def __init__(self, session_factory) + self._sf = session_factory) - auth/config.py: drop users_db_path field — storage is configured through config.database like every other table - deps.py/get_local_provider: construct SQLiteUserRepository with the shared session factory, fail fast if engine is not initialised - tests/test_auth.py: rewrite test_sqlite_round_trip_new_fields to use the shared engine (init_engine + close_engine in a tempdir) - tests/test_auth_type_system.py: add per-test autouse fixture that spins up a scratch engine and resets deps._cached_* singletons
This commit is contained in:
@@ -13,17 +13,19 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""JWT and auth-related configuration. Parsed once at startup."""
|
||||
"""JWT and auth-related configuration. Parsed once at startup.
|
||||
|
||||
Note: the ``users`` table now lives in the shared persistence
|
||||
database managed by ``deerflow.persistence.engine``. The old
|
||||
``users_db_path`` config key has been removed — user storage is
|
||||
configured through ``config.database`` like every other table.
|
||||
"""
|
||||
|
||||
jwt_secret: str = Field(
|
||||
...,
|
||||
description="Secret key for JWT signing. MUST be set via AUTH_JWT_SECRET.",
|
||||
)
|
||||
token_expiry_days: int = Field(default=7, ge=1, le=30)
|
||||
users_db_path: str | None = Field(
|
||||
default=None,
|
||||
description="Path to users SQLite DB. Defaults to .deer-flow/users.db",
|
||||
)
|
||||
oauth_github_client_id: str | None = Field(default=None)
|
||||
oauth_github_client_secret: str | None = Field(default=None)
|
||||
|
||||
|
||||
@@ -1,196 +1,116 @@
|
||||
"""SQLite implementation of UserRepository."""
|
||||
"""SQLAlchemy-backed UserRepository implementation.
|
||||
|
||||
import asyncio
|
||||
import sqlite3
|
||||
from contextlib import contextmanager
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
Uses the shared async session factory from
|
||||
``deerflow.persistence.engine`` — the ``users`` table lives in the
|
||||
same database as ``threads_meta``, ``runs``, ``run_events``, and
|
||||
``feedback``.
|
||||
|
||||
Constructor takes the session factory directly (same pattern as the
|
||||
other four repositories in ``deerflow.persistence.*``). Callers
|
||||
construct this after ``init_engine_from_config()`` has run.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC
|
||||
from uuid import UUID
|
||||
|
||||
from app.gateway.auth.config import get_auth_config
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.auth.repositories.base import UserRepository
|
||||
|
||||
_resolved_db_path: Path | None = None
|
||||
_table_initialized: bool = False
|
||||
|
||||
|
||||
def _get_users_db_path() -> Path:
|
||||
"""Get the users database path (resolved and cached once)."""
|
||||
global _resolved_db_path
|
||||
if _resolved_db_path is not None:
|
||||
return _resolved_db_path
|
||||
config = get_auth_config()
|
||||
if config.users_db_path:
|
||||
_resolved_db_path = Path(config.users_db_path)
|
||||
else:
|
||||
_resolved_db_path = Path(".deer-flow/users.db")
|
||||
_resolved_db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return _resolved_db_path
|
||||
|
||||
|
||||
def _get_connection() -> sqlite3.Connection:
|
||||
"""Get a SQLite connection for the users database."""
|
||||
db_path = _get_users_db_path()
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.row_factory = sqlite3.Row
|
||||
return conn
|
||||
|
||||
|
||||
def _init_users_table(conn: sqlite3.Connection) -> None:
|
||||
"""Initialize the users table if it doesn't exist."""
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id TEXT PRIMARY KEY,
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT,
|
||||
system_role TEXT NOT NULL DEFAULT 'user',
|
||||
created_at REAL NOT NULL,
|
||||
oauth_provider TEXT,
|
||||
oauth_id TEXT,
|
||||
needs_setup INTEGER NOT NULL DEFAULT 0,
|
||||
token_version INTEGER NOT NULL DEFAULT 0
|
||||
)
|
||||
"""
|
||||
)
|
||||
# Add unique constraint for OAuth identity to prevent duplicate social logins
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS idx_users_oauth_identity
|
||||
ON users(oauth_provider, oauth_id)
|
||||
WHERE oauth_provider IS NOT NULL AND oauth_id IS NOT NULL
|
||||
"""
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _get_users_conn():
|
||||
"""Context manager for users database connection."""
|
||||
global _table_initialized
|
||||
conn = _get_connection()
|
||||
try:
|
||||
if not _table_initialized:
|
||||
_init_users_table(conn)
|
||||
_table_initialized = True
|
||||
yield conn
|
||||
finally:
|
||||
conn.close()
|
||||
from deerflow.persistence.user.model import UserRow
|
||||
|
||||
|
||||
class SQLiteUserRepository(UserRepository):
|
||||
"""SQLite implementation of UserRepository."""
|
||||
"""Async user repository backed by the shared SQLAlchemy engine."""
|
||||
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
# ── Converters ────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _row_to_user(row: UserRow) -> User:
|
||||
return User(
|
||||
id=UUID(row.id),
|
||||
email=row.email,
|
||||
password_hash=row.password_hash,
|
||||
system_role=row.system_role, # type: ignore[arg-type]
|
||||
# SQLite loses tzinfo on read; reattach UTC so downstream
|
||||
# code can compare timestamps reliably.
|
||||
created_at=row.created_at if row.created_at.tzinfo else row.created_at.replace(tzinfo=UTC),
|
||||
oauth_provider=row.oauth_provider,
|
||||
oauth_id=row.oauth_id,
|
||||
needs_setup=row.needs_setup,
|
||||
token_version=row.token_version,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _user_to_row(user: User) -> UserRow:
|
||||
return UserRow(
|
||||
id=str(user.id),
|
||||
email=user.email,
|
||||
password_hash=user.password_hash,
|
||||
system_role=user.system_role,
|
||||
created_at=user.created_at,
|
||||
oauth_provider=user.oauth_provider,
|
||||
oauth_id=user.oauth_id,
|
||||
needs_setup=user.needs_setup,
|
||||
token_version=user.token_version,
|
||||
)
|
||||
|
||||
# ── CRUD ──────────────────────────────────────────────────────────
|
||||
|
||||
async def create_user(self, user: User) -> User:
|
||||
"""Create a new user in SQLite."""
|
||||
return await asyncio.to_thread(self._create_user_sync, user)
|
||||
|
||||
def _create_user_sync(self, user: User) -> User:
|
||||
"""Synchronous user creation (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
"""Insert a new user. Raises ``ValueError`` on duplicate email."""
|
||||
row = self._user_to_row(user)
|
||||
async with self._sf() as session:
|
||||
session.add(row)
|
||||
try:
|
||||
conn.execute(
|
||||
"""
|
||||
INSERT INTO users (id, email, password_hash, system_role, created_at, oauth_provider, oauth_id, needs_setup, token_version)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
str(user.id),
|
||||
user.email,
|
||||
user.password_hash,
|
||||
user.system_role,
|
||||
datetime.now(UTC).timestamp(),
|
||||
user.oauth_provider,
|
||||
user.oauth_id,
|
||||
int(user.needs_setup),
|
||||
user.token_version,
|
||||
),
|
||||
)
|
||||
conn.commit()
|
||||
except sqlite3.IntegrityError as e:
|
||||
if "UNIQUE constraint failed: users.email" in str(e):
|
||||
raise ValueError(f"Email already registered: {user.email}") from e
|
||||
raise
|
||||
await session.commit()
|
||||
except IntegrityError as exc:
|
||||
await session.rollback()
|
||||
raise ValueError(f"Email already registered: {user.email}") from exc
|
||||
return user
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
"""Get user by ID from SQLite."""
|
||||
return await asyncio.to_thread(self._get_user_by_id_sync, user_id)
|
||||
|
||||
def _get_user_by_id_sync(self, user_id: str) -> User | None:
|
||||
"""Synchronous get by ID (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute("SELECT * FROM users WHERE id = ?", (user_id,))
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_user(dict(row))
|
||||
async with self._sf() as session:
|
||||
row = await session.get(UserRow, user_id)
|
||||
return self._row_to_user(row) if row is not None else None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
"""Get user by email from SQLite."""
|
||||
return await asyncio.to_thread(self._get_user_by_email_sync, email)
|
||||
|
||||
def _get_user_by_email_sync(self, email: str) -> User | None:
|
||||
"""Synchronous get by email (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute("SELECT * FROM users WHERE email = ?", (email,))
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_user(dict(row))
|
||||
stmt = select(UserRow).where(UserRow.email == email)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
return self._row_to_user(row) if row is not None else None
|
||||
|
||||
async def update_user(self, user: User) -> User:
|
||||
"""Update an existing user in SQLite."""
|
||||
return await asyncio.to_thread(self._update_user_sync, user)
|
||||
|
||||
def _update_user_sync(self, user: User) -> User:
|
||||
with _get_users_conn() as conn:
|
||||
conn.execute(
|
||||
"UPDATE users SET email = ?, password_hash = ?, system_role = ?, oauth_provider = ?, oauth_id = ?, needs_setup = ?, token_version = ? WHERE id = ?",
|
||||
(user.email, user.password_hash, user.system_role, user.oauth_provider, user.oauth_id, int(user.needs_setup), user.token_version, str(user.id)),
|
||||
)
|
||||
conn.commit()
|
||||
async with self._sf() as session:
|
||||
row = await session.get(UserRow, str(user.id))
|
||||
if row is None:
|
||||
return user
|
||||
row.email = user.email
|
||||
row.password_hash = user.password_hash
|
||||
row.system_role = user.system_role
|
||||
row.oauth_provider = user.oauth_provider
|
||||
row.oauth_id = user.oauth_id
|
||||
row.needs_setup = user.needs_setup
|
||||
row.token_version = user.token_version
|
||||
await session.commit()
|
||||
return user
|
||||
|
||||
async def count_users(self) -> int:
|
||||
"""Return total number of registered users."""
|
||||
return await asyncio.to_thread(self._count_users_sync)
|
||||
|
||||
def _count_users_sync(self) -> int:
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute("SELECT COUNT(*) FROM users")
|
||||
return cursor.fetchone()[0]
|
||||
stmt = select(func.count()).select_from(UserRow)
|
||||
async with self._sf() as session:
|
||||
return await session.scalar(stmt) or 0
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Get user by OAuth provider and ID from SQLite."""
|
||||
return await asyncio.to_thread(self._get_user_by_oauth_sync, provider, oauth_id)
|
||||
|
||||
def _get_user_by_oauth_sync(self, provider: str, oauth_id: str) -> User | None:
|
||||
"""Synchronous get by OAuth (runs in thread pool)."""
|
||||
with _get_users_conn() as conn:
|
||||
cursor = conn.execute(
|
||||
"SELECT * FROM users WHERE oauth_provider = ? AND oauth_id = ?",
|
||||
(provider, oauth_id),
|
||||
)
|
||||
row = cursor.fetchone()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_user(dict(row))
|
||||
|
||||
@staticmethod
|
||||
def _row_to_user(row: dict[str, Any]) -> User:
|
||||
"""Convert a database row to a User model."""
|
||||
return User(
|
||||
id=UUID(row["id"]),
|
||||
email=row["email"],
|
||||
password_hash=row["password_hash"],
|
||||
system_role=row["system_role"],
|
||||
created_at=datetime.fromtimestamp(row["created_at"], tz=UTC),
|
||||
oauth_provider=row.get("oauth_provider"),
|
||||
oauth_id=row.get("oauth_id"),
|
||||
needs_setup=bool(row["needs_setup"]),
|
||||
token_version=int(row["token_version"]),
|
||||
)
|
||||
stmt = select(UserRow).where(UserRow.oauth_provider == provider, UserRow.oauth_id == oauth_id)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
row = result.scalar_one_or_none()
|
||||
return self._row_to_user(row) if row is not None else None
|
||||
|
||||
@@ -142,12 +142,20 @@ _cached_repo: SQLiteUserRepository | None = None
|
||||
|
||||
|
||||
def get_local_provider() -> LocalAuthProvider:
|
||||
"""Get or create the cached LocalAuthProvider singleton."""
|
||||
"""Get or create the cached LocalAuthProvider singleton.
|
||||
|
||||
Must be called after ``init_engine_from_config()`` — the shared
|
||||
session factory is required to construct the user repository.
|
||||
"""
|
||||
global _cached_local_provider, _cached_repo
|
||||
if _cached_repo is None:
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
from deerflow.persistence.engine import get_session_factory
|
||||
|
||||
_cached_repo = SQLiteUserRepository()
|
||||
sf = get_session_factory()
|
||||
if sf is None:
|
||||
raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table")
|
||||
_cached_repo = SQLiteUserRepository(sf)
|
||||
if _cached_local_provider is None:
|
||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||
|
||||
|
||||
Reference in New Issue
Block a user