"""Centralized accessors for singleton objects stored on ``app.state``. **Getters** (used by routers): raise 503 when a required dependency is missing, except ``get_store`` which returns ``None``. Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`. """ from __future__ import annotations from collections.abc import AsyncGenerator, Callable from contextlib import AsyncExitStack, asynccontextmanager from typing import TYPE_CHECKING, TypeVar, cast from fastapi import FastAPI, HTTPException, Request from langgraph.types import Checkpointer from deerflow.persistence.feedback import FeedbackRepository from deerflow.runtime import RunContext, RunManager, StreamBridge from deerflow.runtime.events.store.base import RunEventStore from deerflow.runtime.runs.store.base import RunStore if TYPE_CHECKING: from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.repositories.sqlite import SQLiteUserRepository from deerflow.persistence.thread_meta.base import ThreadMetaStore T = TypeVar("T") @asynccontextmanager async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]: """Bootstrap and tear down all LangGraph runtime singletons. Usage in ``app.py``:: async with langgraph_runtime(app): yield """ from deerflow.config import get_app_config from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config from deerflow.runtime import make_store, make_stream_bridge from deerflow.runtime.checkpointer.async_provider import make_checkpointer from deerflow.runtime.events.store import make_run_event_store async with AsyncExitStack() as stack: app.state.stream_bridge = await stack.enter_async_context(make_stream_bridge()) # Initialize persistence engine BEFORE checkpointer so that # auto-create-database logic runs first (postgres backend). config = get_app_config() await init_engine_from_config(config.database) app.state.checkpointer = await stack.enter_async_context(make_checkpointer()) app.state.store = await stack.enter_async_context(make_store()) # Initialize repositories — one get_session_factory() call for all. sf = get_session_factory() if sf is not None: from deerflow.persistence.feedback import FeedbackRepository from deerflow.persistence.run import RunRepository app.state.run_store = RunRepository(sf) app.state.feedback_repo = FeedbackRepository(sf) else: from deerflow.runtime.runs.store.memory import MemoryRunStore app.state.run_store = MemoryRunStore() app.state.feedback_repo = None from deerflow.persistence.thread_meta import make_thread_store app.state.thread_store = make_thread_store(sf, app.state.store) # Run event store (has its own factory with config-driven backend selection) run_events_config = getattr(config, "run_events", None) app.state.run_event_store = make_run_event_store(run_events_config) # RunManager with store backing for persistence app.state.run_manager = RunManager(store=app.state.run_store) try: yield finally: await close_engine() # --------------------------------------------------------------------------- # Getters – called by routers per-request # --------------------------------------------------------------------------- def _require(attr: str, label: str) -> Callable[[Request], T]: """Create a FastAPI dependency that returns ``app.state.`` or 503.""" def dep(request: Request) -> T: val = getattr(request.app.state, attr, None) if val is None: raise HTTPException(status_code=503, detail=f"{label} not available") return cast(T, val) dep.__name__ = dep.__qualname__ = f"get_{attr}" return dep get_stream_bridge: Callable[[Request], StreamBridge] = _require("stream_bridge", "Stream bridge") get_run_manager: Callable[[Request], RunManager] = _require("run_manager", "Run manager") get_checkpointer: Callable[[Request], Checkpointer] = _require("checkpointer", "Checkpointer") get_run_event_store: Callable[[Request], RunEventStore] = _require("run_event_store", "Run event store") get_feedback_repo: Callable[[Request], FeedbackRepository] = _require("feedback_repo", "Feedback") get_run_store: Callable[[Request], RunStore] = _require("run_store", "Run store") def get_store(request: Request): """Return the global store (may be ``None`` if not configured).""" return getattr(request.app.state, "store", None) def get_thread_store(request: Request) -> ThreadMetaStore: """Return the thread metadata store (SQL or memory-backed).""" val = getattr(request.app.state, "thread_store", None) if val is None: raise HTTPException(status_code=503, detail="Thread metadata store not available") return val def get_run_context(request: Request) -> RunContext: """Build a :class:`RunContext` from ``app.state`` singletons. Returns a *base* context with infrastructure dependencies. """ from deerflow.config import get_app_config return RunContext( checkpointer=get_checkpointer(request), store=get_store(request), event_store=get_run_event_store(request), run_events_config=getattr(get_app_config(), "run_events", None), thread_store=get_thread_store(request), ) # --------------------------------------------------------------------------- # Auth helpers (used by authz.py and auth middleware) # --------------------------------------------------------------------------- # Cached singletons to avoid repeated instantiation per request _cached_local_provider: LocalAuthProvider | None = None _cached_repo: SQLiteUserRepository | None = None def get_local_provider() -> LocalAuthProvider: """Get or create the cached LocalAuthProvider singleton. Must be called after ``init_engine_from_config()`` — the shared session factory is required to construct the user repository. """ global _cached_local_provider, _cached_repo if _cached_repo is None: from app.gateway.auth.repositories.sqlite import SQLiteUserRepository from deerflow.persistence.engine import get_session_factory sf = get_session_factory() if sf is None: raise RuntimeError("get_local_provider() called before init_engine_from_config(); cannot access users table") _cached_repo = SQLiteUserRepository(sf) if _cached_local_provider is None: from app.gateway.auth.local_provider import LocalAuthProvider _cached_local_provider = LocalAuthProvider(repository=_cached_repo) return _cached_local_provider async def get_current_user_from_request(request: Request): """Get the current authenticated user from the request cookie. Raises HTTPException 401 if not authenticated. """ from app.gateway.auth import decode_token from app.gateway.auth.errors import AuthErrorCode, AuthErrorResponse, TokenError, token_error_to_code access_token = request.cookies.get("access_token") if not access_token: raise HTTPException( status_code=401, detail=AuthErrorResponse(code=AuthErrorCode.NOT_AUTHENTICATED, message="Not authenticated").model_dump(), ) payload = decode_token(access_token) if isinstance(payload, TokenError): raise HTTPException( status_code=401, detail=AuthErrorResponse(code=token_error_to_code(payload), message=f"Token error: {payload.value}").model_dump(), ) provider = get_local_provider() user = await provider.get_user(payload.sub) if user is None: raise HTTPException( status_code=401, detail=AuthErrorResponse(code=AuthErrorCode.USER_NOT_FOUND, message="User not found").model_dump(), ) # Token version mismatch → password was changed, token is stale if user.token_version != payload.ver: raise HTTPException( status_code=401, detail=AuthErrorResponse(code=AuthErrorCode.TOKEN_INVALID, message="Token revoked (password changed)").model_dump(), ) return user async def get_optional_user_from_request(request: Request): """Get optional authenticated user from request. Returns None if not authenticated. """ try: return await get_current_user_from_request(request) except HTTPException: return None async def get_current_user(request: Request) -> str | None: """Extract user_id from request cookie, or None if not authenticated. Thin adapter that returns the string id for callers that only need identification (e.g., ``feedback.py``). Full-user callers should use ``get_current_user_from_request`` or ``get_optional_user_from_request``. """ user = await get_optional_user_from_request(request) return str(user.id) if user else None