mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
refactor(auth): migrate user repository to SQLAlchemy ORM
Move the users table into the shared persistence engine so auth matches the pattern of threads_meta, runs, run_events, and feedback — one engine, one session factory, one schema init codepath. New files --------- - persistence/user/__init__.py, persistence/user/model.py: UserRow ORM class with partial unique index on (oauth_provider, oauth_id) - Registered in persistence/models/__init__.py so Base.metadata.create_all() picks it up Modified -------- - auth/repositories/sqlite.py: rewritten as async SQLAlchemy, identical constructor pattern to the other four repositories (def __init__(self, session_factory) + self._sf = session_factory) - auth/config.py: drop users_db_path field — storage is configured through config.database like every other table - deps.py/get_local_provider: construct SQLiteUserRepository with the shared session factory, fail fast if engine is not initialised - tests/test_auth.py: rewrite test_sqlite_round_trip_new_fields to use the shared engine (init_engine + close_engine in a tempdir) - tests/test_auth_type_system.py: add per-test autouse fixture that spins up a scratch engine and resets deps._cached_* singletons
This commit is contained in:
+44
-35
@@ -262,47 +262,56 @@ def test_user_model_needs_setup_true():
|
||||
|
||||
|
||||
def test_sqlite_round_trip_new_fields():
|
||||
"""needs_setup and token_version survive create → read round-trip."""
|
||||
"""needs_setup and token_version survive create → read round-trip.
|
||||
|
||||
Uses the shared persistence engine (same one threads_meta, runs,
|
||||
run_events, and feedback use). The old separate .deer-flow/users.db
|
||||
file is gone.
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
from app.gateway.auth.repositories import sqlite as sqlite_mod
|
||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
db_path = os.path.join(tmpdir, "test_users.db")
|
||||
old_path = sqlite_mod._resolved_db_path
|
||||
old_init = sqlite_mod._table_initialized
|
||||
sqlite_mod._resolved_db_path = Path(db_path)
|
||||
sqlite_mod._table_initialized = False
|
||||
try:
|
||||
repo = sqlite_mod.SQLiteUserRepository()
|
||||
user = User(
|
||||
email="setup@test.com",
|
||||
password_hash="fakehash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
)
|
||||
created = asyncio.run(repo.create_user(user))
|
||||
assert created.needs_setup is True
|
||||
assert created.token_version == 3
|
||||
async def _run() -> None:
|
||||
from deerflow.persistence.engine import (
|
||||
close_engine,
|
||||
get_session_factory,
|
||||
init_engine,
|
||||
)
|
||||
|
||||
fetched = asyncio.run(repo.get_user_by_email("setup@test.com"))
|
||||
assert fetched is not None
|
||||
assert fetched.needs_setup is True
|
||||
assert fetched.token_version == 3
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
url = f"sqlite+aiosqlite:///{tmpdir}/scratch.db"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=tmpdir)
|
||||
try:
|
||||
repo = SQLiteUserRepository(get_session_factory())
|
||||
user = User(
|
||||
email="setup@test.com",
|
||||
password_hash="fakehash",
|
||||
system_role="admin",
|
||||
needs_setup=True,
|
||||
token_version=3,
|
||||
)
|
||||
created = await repo.create_user(user)
|
||||
assert created.needs_setup is True
|
||||
assert created.token_version == 3
|
||||
|
||||
fetched.needs_setup = False
|
||||
fetched.token_version = 4
|
||||
asyncio.run(repo.update_user(fetched))
|
||||
refetched = asyncio.run(repo.get_user_by_id(str(fetched.id)))
|
||||
assert refetched.needs_setup is False
|
||||
assert refetched.token_version == 4
|
||||
finally:
|
||||
sqlite_mod._resolved_db_path = old_path
|
||||
sqlite_mod._table_initialized = old_init
|
||||
fetched = await repo.get_user_by_email("setup@test.com")
|
||||
assert fetched is not None
|
||||
assert fetched.needs_setup is True
|
||||
assert fetched.token_version == 3
|
||||
|
||||
fetched.needs_setup = False
|
||||
fetched.token_version = 4
|
||||
await repo.update_user(fetched)
|
||||
refetched = await repo.get_user_by_id(str(fetched.id))
|
||||
assert refetched is not None
|
||||
assert refetched.needs_setup is False
|
||||
assert refetched.token_version == 4
|
||||
finally:
|
||||
await close_engine()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ── Token Versioning ───────────────────────────────────────────────────────
|
||||
|
||||
@@ -32,6 +32,32 @@ from app.gateway.csrf_middleware import (
|
||||
_TEST_SECRET = "test-secret-for-auth-type-system-tests-min32"
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _persistence_engine(tmp_path):
|
||||
"""Initialise a per-test SQLite engine + reset cached provider singletons.
|
||||
|
||||
The auth tests call real HTTP handlers that go through
|
||||
``SQLiteUserRepository`` → ``get_session_factory``. Each test gets
|
||||
a fresh DB plus a clean ``deps._cached_*`` so the cached provider
|
||||
does not hold a dangling reference to the previous test's engine.
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from app.gateway import deps
|
||||
from deerflow.persistence.engine import close_engine, init_engine
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path}/auth_types.db"
|
||||
asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
|
||||
deps._cached_local_provider = None
|
||||
deps._cached_repo = None
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
deps._cached_local_provider = None
|
||||
deps._cached_repo = None
|
||||
asyncio.run(close_engine())
|
||||
|
||||
|
||||
def _setup_config():
|
||||
set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user