diff --git a/backend/app/gateway/routers/auth.py b/backend/app/gateway/routers/auth.py
index 6192456fb..e57182c26 100644
--- a/backend/app/gateway/routers/auth.py
+++ b/backend/app/gateway/routers/auth.py
@@ -1,5 +1,6 @@
 """Authentication endpoints."""
 
+import asyncio
 import logging
 import os
 import time
@@ -382,9 +383,15 @@ async def get_me(request: Request):
     return UserResponse(id=str(user.id), email=user.email, system_role=user.system_role, needs_setup=user.needs_setup)
 
 
-_SETUP_STATUS_COOLDOWN: dict[str, float] = {}
-_SETUP_STATUS_COOLDOWN_SECONDS = 60
+# Per-IP cache: ip → (timestamp, result_dict).
+# Returns the cached result within the TTL instead of 429, because
+# the answer (whether an admin exists) rarely changes and returning
+# 429 breaks multi-tab / post-restart reconnection storms.
+_SETUP_STATUS_CACHE: dict[str, tuple[float, dict]] = {}
+_SETUP_STATUS_CACHE_TTL_SECONDS = 60
 _MAX_TRACKED_SETUP_STATUS_IPS = 10000
+_SETUP_STATUS_INFLIGHT: dict[str, asyncio.Task[dict]] = {}
+_SETUP_STATUS_INFLIGHT_GUARD = asyncio.Lock()
 
 
 @router.get("/setup-status")
@@ -392,29 +399,59 @@ async def setup_status(request: Request):
     """Check if an admin account exists. Returns needs_setup=True when no admin exists."""
     client_ip = _get_client_ip(request)
     now = time.time()
-    last_check = _SETUP_STATUS_COOLDOWN.get(client_ip, 0)
-    elapsed = now - last_check
-    if elapsed < _SETUP_STATUS_COOLDOWN_SECONDS:
-        retry_after = max(1, int(_SETUP_STATUS_COOLDOWN_SECONDS - elapsed))
-        raise HTTPException(
-            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
-            detail="Setup status check is rate limited",
-            headers={"Retry-After": str(retry_after)},
-        )
-    # Evict stale entries when dict grows too large to bound memory usage.
-    if len(_SETUP_STATUS_COOLDOWN) >= _MAX_TRACKED_SETUP_STATUS_IPS:
-        cutoff = now - _SETUP_STATUS_COOLDOWN_SECONDS
-        stale = [k for k, t in _SETUP_STATUS_COOLDOWN.items() if t < cutoff]
-        for k in stale:
-            del _SETUP_STATUS_COOLDOWN[k]
-        # If still too large after evicting expired entries, remove oldest half.
-        if len(_SETUP_STATUS_COOLDOWN) >= _MAX_TRACKED_SETUP_STATUS_IPS:
-            by_time = sorted(_SETUP_STATUS_COOLDOWN.items(), key=lambda kv: kv[1])
-            for k, _ in by_time[: len(by_time) // 2]:
-                del _SETUP_STATUS_COOLDOWN[k]
-    _SETUP_STATUS_COOLDOWN[client_ip] = now
-    admin_count = await get_local_provider().count_admin_users()
-    return {"needs_setup": admin_count == 0}
+
+    # Return cached result when within TTL — avoids 429 on multi-tab reconnection.
+    cached = _SETUP_STATUS_CACHE.get(client_ip)
+    if cached is not None:
+        cached_time, cached_result = cached
+        if now - cached_time < _SETUP_STATUS_CACHE_TTL_SECONDS:
+            return cached_result
+
+    async with _SETUP_STATUS_INFLIGHT_GUARD:
+        # Recheck cache after waiting for the inflight guard.
+        now = time.time()
+        cached = _SETUP_STATUS_CACHE.get(client_ip)
+        if cached is not None:
+            cached_time, cached_result = cached
+            if now - cached_time < _SETUP_STATUS_CACHE_TTL_SECONDS:
+                return cached_result
+
+        task = _SETUP_STATUS_INFLIGHT.get(client_ip)
+        if task is None:
+            # Evict stale entries when dict grows too large to bound memory usage.
+            if len(_SETUP_STATUS_CACHE) >= _MAX_TRACKED_SETUP_STATUS_IPS:
+                cutoff = now - _SETUP_STATUS_CACHE_TTL_SECONDS
+                stale = [k for k, (t, _) in _SETUP_STATUS_CACHE.items() if t < cutoff]
+                for k in stale:
+                    del _SETUP_STATUS_CACHE[k]
+                if len(_SETUP_STATUS_CACHE) >= _MAX_TRACKED_SETUP_STATUS_IPS:
+                    by_time = sorted(_SETUP_STATUS_CACHE.items(), key=lambda entry: entry[1][0])
+                    for k, _ in by_time[: len(by_time) // 2]:
+                        del _SETUP_STATUS_CACHE[k]
+
+            async def _compute_setup_status() -> dict:
+                admin_count = await get_local_provider().count_admin_users()
+                return {"needs_setup": admin_count == 0}
+
+            def _forget_inflight(done: asyncio.Task[dict]) -> None:
+                # Runs once when the query settles, independent of which waiters are still alive.
+                if _SETUP_STATUS_INFLIGHT.get(client_ip) is done:
+                    del _SETUP_STATUS_INFLIGHT[client_ip]
+
+            task = asyncio.create_task(_compute_setup_status())
+            task.add_done_callback(_forget_inflight)
+            _SETUP_STATUS_INFLIGHT[client_ip] = task
+
+    # Shield the shared task: awaiting it directly would let one caller's
+    # cancellation (e.g. a client disconnect) cancel the query for every waiter.
+    result = await asyncio.shield(task)
+
+    # Cache only the stable "initialized" result to avoid stale setup redirects.
+    if result["needs_setup"] is False:
+        _SETUP_STATUS_CACHE[client_ip] = (time.time(), result)
+    else:
+        _SETUP_STATUS_CACHE.pop(client_ip, None)
+    return result
 
 
 class InitializeAdminRequest(BaseModel):
diff --git a/backend/tests/test_initialize_admin.py b/backend/tests/test_initialize_admin.py
index 26b2ec6b2..514ee6df3 100644
--- a/backend/tests/test_initialize_admin.py
+++ b/backend/tests/test_initialize_admin.py
@@ -22,7 +22,7 @@ _TEST_SECRET = "test-secret-key-initialize-admin-min-32"
 def _setup_auth(tmp_path):
     """Fresh SQLite engine + auth config per test."""
     from app.gateway import deps
-    from app.gateway.routers.auth import _SETUP_STATUS_COOLDOWN
+    from app.gateway.routers.auth import _SETUP_STATUS_CACHE, _SETUP_STATUS_INFLIGHT
     from deerflow.persistence.engine import close_engine, init_engine
 
     set_auth_config(AuthConfig(jwt_secret=_TEST_SECRET))
@@ -30,13 +30,15 @@ def _setup_auth(tmp_path):
     asyncio.run(init_engine("sqlite", url=url, sqlite_dir=str(tmp_path)))
     deps._cached_local_provider = None
     deps._cached_repo = None
-    _SETUP_STATUS_COOLDOWN.clear()
+    _SETUP_STATUS_CACHE.clear()
+    _SETUP_STATUS_INFLIGHT.clear()
     try:
         yield
     finally:
         deps._cached_local_provider = None
         deps._cached_repo = None
-        _SETUP_STATUS_COOLDOWN.clear()
+        _SETUP_STATUS_CACHE.clear()
+        _SETUP_STATUS_INFLIGHT.clear()
         asyncio.run(close_engine())
 
 
@@ -168,15 +170,76 @@ def test_setup_status_false_when_only_regular_user_exists(client):
     assert resp.json()["needs_setup"] is True
 
 
-def test_setup_status_rate_limited_on_second_call(client):
-    """Second /setup-status call within the cooldown window returns 429 with Retry-After."""
-    # First call succeeds.
+def test_setup_status_returns_cached_result_on_rapid_calls(client):
+    """Rapid /setup-status calls return the cached result (200) instead of 429."""
+    client.post("/api/v1/auth/initialize", json=_init_payload())
+
+    # First call succeeds and computes the result.
     resp1 = client.get("/api/v1/auth/setup-status")
     assert resp1.status_code == 200
 
-    # Immediate second call is rate-limited.
+    # Immediate second call returns cached result, not 429.
     resp2 = client.get("/api/v1/auth/setup-status")
-    assert resp2.status_code == 429
-    assert "Retry-After" in resp2.headers
-    retry_after = int(resp2.headers["Retry-After"])
-    assert 1 <= retry_after <= 60
+    assert resp2.status_code == 200
+    assert resp2.json() == resp1.json()
+    assert resp2.json()["needs_setup"] is False
+
+
+def test_setup_status_does_not_return_stale_true_after_initialize(client):
+    """A pre-initialize setup-status response should not stay cached as True."""
+    before = client.get("/api/v1/auth/setup-status")
+    assert before.status_code == 200
+    assert before.json()["needs_setup"] is True
+
+    init = client.post("/api/v1/auth/initialize", json=_init_payload())
+    assert init.status_code == 201
+
+    after = client.get("/api/v1/auth/setup-status")
+    assert after.status_code == 200
+    assert after.json()["needs_setup"] is False
+
+
+@pytest.mark.asyncio
+async def test_setup_status_single_flight_per_ip(monkeypatch):
+    """Concurrent requests from same IP share one in-flight DB query."""
+    from starlette.requests import Request
+
+    from app.gateway.routers.auth import (
+        _SETUP_STATUS_CACHE,
+        _SETUP_STATUS_INFLIGHT,
+        setup_status,
+    )
+
+    class _Provider:
+        def __init__(self):
+            self.calls = 0
+
+        async def count_admin_users(self):
+            self.calls += 1
+            await asyncio.sleep(0.05)
+            return 0
+
+    provider = _Provider()
+    monkeypatch.setattr("app.gateway.routers.auth.get_local_provider", lambda: provider)
+    _SETUP_STATUS_CACHE.clear()
+    _SETUP_STATUS_INFLIGHT.clear()
+
+    def _request() -> Request:
+        return Request(
+            {
+                "type": "http",
+                "method": "GET",
+                "path": "/api/v1/auth/setup-status",
+                "headers": [],
+                "client": ("127.0.0.1", 12345),
+            }
+        )
+
+    results = await asyncio.gather(
+        setup_status(_request()),
+        setup_status(_request()),
+        setup_status(_request()),
+    )
+
+    assert all(result["needs_setup"] is True for result in results)
+    assert provider.calls == 1