178 lines
6.1 KiB
Python
178 lines
6.1 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
from collections.abc import AsyncGenerator
|
|
from contextlib import asynccontextmanager
|
|
from pathlib import Path
|
|
from uuid import uuid4
|
|
|
|
import pytest
|
|
|
|
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
from store.repositories import UserCreate, UserNotFoundError, build_user_repository
|
|
from store.repositories.models import User as UserModel
|
|
|
|
|
|
@asynccontextmanager
|
|
async def _session_factory(tmp_path) -> AsyncGenerator[async_sessionmaker[AsyncSession]]:
|
|
db_path = tmp_path / "storage-users.db"
|
|
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
|
|
async with engine.begin() as conn:
|
|
await conn.run_sync(UserModel.metadata.create_all)
|
|
|
|
try:
|
|
yield async_sessionmaker(engine, expire_on_commit=False)
|
|
finally:
|
|
await engine.dispose()
|
|
|
|
|
|
async def _create_user(
|
|
session_factory: async_sessionmaker[AsyncSession],
|
|
*,
|
|
email: str = "user@example.com",
|
|
system_role: str = "user",
|
|
oauth_provider: str | None = None,
|
|
oauth_id: str | None = None,
|
|
):
|
|
async with session_factory() as session:
|
|
repo = build_user_repository(session)
|
|
user = await repo.create_user(
|
|
UserCreate(
|
|
id=str(uuid4()),
|
|
email=email,
|
|
password_hash="hash",
|
|
system_role=system_role, # type: ignore[arg-type]
|
|
oauth_provider=oauth_provider,
|
|
oauth_id=oauth_id,
|
|
)
|
|
)
|
|
await session.commit()
|
|
return user
|
|
|
|
|
|
def test_create_and_get_user_by_id_and_email(tmp_path):
|
|
async def run() -> None:
|
|
async with _session_factory(tmp_path) as session_factory:
|
|
created = await _create_user(session_factory)
|
|
|
|
async with session_factory() as session:
|
|
repo = build_user_repository(session)
|
|
|
|
by_id = await repo.get_user_by_id(created.id)
|
|
by_email = await repo.get_user_by_email(created.email)
|
|
|
|
assert by_id == created
|
|
assert by_email == created
|
|
assert created.system_role == "user"
|
|
assert created.needs_setup is False
|
|
assert created.token_version == 0
|
|
|
|
asyncio.run(run())
|
|
|
|
|
|
def test_duplicate_email_raises_value_error(tmp_path):
|
|
async def run() -> None:
|
|
async with _session_factory(tmp_path) as session_factory:
|
|
await _create_user(session_factory, email="dupe@example.com")
|
|
|
|
async with session_factory() as session:
|
|
repo = build_user_repository(session)
|
|
with pytest.raises(ValueError, match="Email already registered"):
|
|
await repo.create_user(
|
|
UserCreate(
|
|
id=str(uuid4()),
|
|
email="dupe@example.com",
|
|
password_hash="hash",
|
|
)
|
|
)
|
|
|
|
asyncio.run(run())
|
|
|
|
|
|
def test_oauth_lookup_and_plain_users_without_oauth(tmp_path):
|
|
async def run() -> None:
|
|
async with _session_factory(tmp_path) as session_factory:
|
|
await _create_user(session_factory, email="local-1@example.com")
|
|
await _create_user(session_factory, email="local-2@example.com")
|
|
oauth_user = await _create_user(
|
|
session_factory,
|
|
email="oauth@example.com",
|
|
oauth_provider="github",
|
|
oauth_id="gh-123",
|
|
)
|
|
|
|
async with session_factory() as session:
|
|
repo = build_user_repository(session)
|
|
|
|
assert await repo.count_users() == 3
|
|
assert await repo.get_user_by_oauth("github", "gh-123") == oauth_user
|
|
assert await repo.get_user_by_oauth("github", "missing") is None
|
|
|
|
asyncio.run(run())
|
|
|
|
|
|
def test_count_admins_and_get_first_admin(tmp_path):
|
|
async def run() -> None:
|
|
async with _session_factory(tmp_path) as session_factory:
|
|
await _create_user(session_factory, email="user@example.com")
|
|
admin = await _create_user(
|
|
session_factory,
|
|
email="admin@example.com",
|
|
system_role="admin",
|
|
)
|
|
|
|
async with session_factory() as session:
|
|
repo = build_user_repository(session)
|
|
|
|
assert await repo.count_users() == 2
|
|
assert await repo.count_admin_users() == 1
|
|
assert await repo.get_first_admin() == admin
|
|
|
|
asyncio.run(run())
|
|
|
|
|
|
def test_update_user_round_trips_token_version_and_setup_state(tmp_path):
|
|
async def run() -> None:
|
|
async with _session_factory(tmp_path) as session_factory:
|
|
created = await _create_user(session_factory)
|
|
updated = created.model_copy(
|
|
update={
|
|
"email": "renamed@example.com",
|
|
"token_version": 4,
|
|
"needs_setup": True,
|
|
}
|
|
)
|
|
|
|
async with session_factory() as session:
|
|
repo = build_user_repository(session)
|
|
saved = await repo.update_user(updated)
|
|
await session.commit()
|
|
|
|
async with session_factory() as session:
|
|
repo = build_user_repository(session)
|
|
fetched = await repo.get_user_by_id(created.id)
|
|
|
|
assert saved.email == "renamed@example.com"
|
|
assert fetched == updated
|
|
|
|
asyncio.run(run())
|
|
|
|
|
|
def test_update_missing_user_raises(tmp_path):
|
|
async def run() -> None:
|
|
async with _session_factory(tmp_path) as session_factory:
|
|
missing = UserCreate(id=str(uuid4()), email="missing@example.com")
|
|
|
|
async with session_factory() as session:
|
|
repo = build_user_repository(session)
|
|
created_shape = await repo.create_user(missing)
|
|
await session.rollback()
|
|
|
|
with pytest.raises(UserNotFoundError):
|
|
await repo.update_user(created_shape)
|
|
|
|
asyncio.run(run())
|