refactor(persistence): rename owner_id to user_id and thread_meta_repo to thread_store

Rename owner_id to user_id across all persistence models, repositories,
stores, routers, and tests for clearer semantics. Rename thread_meta_repo
to thread_store for consistency with run_store/run_event_store naming.
Add ThreadMetaStore return type annotation to get_thread_store().

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-10 15:05:10 +08:00
parent 03952eca53
commit 8da1903168
32 changed files with 256 additions and 276 deletions
+11 -4
View File
@@ -42,6 +42,11 @@ logger = logging.getLogger(__name__)
async def _ensure_admin_user(app: FastAPI) -> None: async def _ensure_admin_user(app: FastAPI) -> None:
"""Startup hook: generate init token on first boot; migrate orphan threads otherwise. """Startup hook: generate init token on first boot; migrate orphan threads otherwise.
After admin creation, migrate orphan threads from the LangGraph
store (metadata.user_id unset) to the admin account. This is the
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
authentication have existing LangGraph thread data that needs an
owner assigned.
First boot (no admin exists): First boot (no admin exists):
- Generates a one-time ``init_token`` stored in ``app.state.init_token`` - Generates a one-time ``init_token`` stored in ``app.state.init_token``
- Logs the token to stdout so the operator can copy-paste it into the - Logs the token to stdout so the operator can copy-paste it into the
@@ -52,7 +57,7 @@ async def _ensure_admin_user(app: FastAPI) -> None:
- Runs the one-time "no-auth → with-auth" orphan thread migration for - Runs the one-time "no-auth → with-auth" orphan thread migration for
existing LangGraph thread metadata that has no owner_id. existing LangGraph thread metadata that has no owner_id.
No SQL persistence migration is needed: the four owner_id columns No SQL persistence migration is needed: the four user_id columns
(threads_meta, runs, run_events, feedback) only come into existence (threads_meta, runs, run_events, feedback) only come into existence
alongside the auth module via create_all, so freshly created tables alongside the auth module via create_all, so freshly created tables
never contain NULL-owner rows. never contain NULL-owner rows.
@@ -96,6 +101,8 @@ async def _ensure_admin_user(app: FastAPI) -> None:
admin_id = str(row.id) admin_id = str(row.id)
# LangGraph store orphan migration — non-fatal. # LangGraph store orphan migration — non-fatal.
# This covers the "no-auth → with-auth" upgrade path for users
# whose existing LangGraph thread metadata has no user_id set.
store = getattr(app.state, "store", None) store = getattr(app.state, "store", None)
if store is not None: if store is not None:
try: try:
@@ -127,7 +134,7 @@ async def _iter_store_items(store, namespace, *, page_size: int = 500):
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int: async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
"""Migrate LangGraph store threads with no owner_id to the given admin. """Migrate LangGraph store threads with no user_id to the given admin.
Uses cursor pagination so all orphans are migrated regardless of Uses cursor pagination so all orphans are migrated regardless of
count. Returns the number of rows migrated. count. Returns the number of rows migrated.
@@ -135,8 +142,8 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
migrated = 0 migrated = 0
async for item in _iter_store_items(store, ("threads",)): async for item in _iter_store_items(store, ("threads",)):
metadata = item.value.get("metadata", {}) metadata = item.value.get("metadata", {})
if not metadata.get("owner_id"): if not metadata.get("user_id"):
metadata["owner_id"] = admin_user_id metadata["user_id"] = admin_user_id
item.value["metadata"] = metadata item.value["metadata"] = metadata
await store.aput(("threads",), item.key, item.value) await store.aput(("threads",), item.key, item.value)
migrated += 1 migrated += 1
+5 -5
View File
@@ -233,18 +233,18 @@ def require_permission(
# (``threads_meta`` table). We verify ownership via # (``threads_meta`` table). We verify ownership via
# ``ThreadMetaStore.check_access``: it returns True for # ``ThreadMetaStore.check_access``: it returns True for
# missing rows (untracked legacy thread) and for rows whose # missing rows (untracked legacy thread) and for rows whose
# ``owner_id`` is NULL (shared / pre-auth data), so this is # ``user_id`` is NULL (shared / pre-auth data), so this is
# strict-deny rather than strict-allow — only an *existing* # strict-deny rather than strict-allow — only an *existing*
# row with a *different* owner_id triggers 404. # row with a *different* user_id triggers 404.
if owner_check: if owner_check:
thread_id = kwargs.get("thread_id") thread_id = kwargs.get("thread_id")
if thread_id is None: if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter") raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
from app.gateway.deps import get_thread_meta_repo from app.gateway.deps import get_thread_store
thread_meta_repo = get_thread_meta_repo(request) thread_store = get_thread_store(request)
allowed = await thread_meta_repo.check_access( allowed = await thread_store.check_access(
thread_id, thread_id,
str(auth.user.id), str(auth.user.id),
require_existing=require_existing, require_existing=require_existing,
+14 -9
View File
@@ -1,8 +1,7 @@
"""Centralized accessors for singleton objects stored on ``app.state``. """Centralized accessors for singleton objects stored on ``app.state``.
**Getters** (used by routers): raise 503 when a required dependency is **Getters** (used by routers): raise 503 when a required dependency is
missing, except ``get_store`` and ``get_thread_meta_repo`` which return missing, except ``get_store`` which returns ``None``.
``None``.
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`. Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
""" """
@@ -20,6 +19,7 @@ from deerflow.runtime import RunContext, RunManager
if TYPE_CHECKING: if TYPE_CHECKING:
from app.gateway.auth.local_provider import LocalAuthProvider from app.gateway.auth.local_provider import LocalAuthProvider
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
from deerflow.persistence.thread_meta.base import ThreadMetaStore
@asynccontextmanager @asynccontextmanager
@@ -31,10 +31,10 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
async with langgraph_runtime(app): async with langgraph_runtime(app):
yield yield
""" """
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
from deerflow.config import get_app_config from deerflow.config import get_app_config
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
from deerflow.runtime import make_store, make_stream_bridge from deerflow.runtime import make_store, make_stream_bridge
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
from deerflow.runtime.events.store import make_run_event_store from deerflow.runtime.events.store import make_run_event_store
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
@@ -53,18 +53,18 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
if sf is not None: if sf is not None:
from deerflow.persistence.feedback import FeedbackRepository from deerflow.persistence.feedback import FeedbackRepository
from deerflow.persistence.run import RunRepository from deerflow.persistence.run import RunRepository
from deerflow.persistence.thread_meta import ThreadMetaRepository
app.state.run_store = RunRepository(sf) app.state.run_store = RunRepository(sf)
app.state.feedback_repo = FeedbackRepository(sf) app.state.feedback_repo = FeedbackRepository(sf)
app.state.thread_meta_repo = ThreadMetaRepository(sf)
else: else:
from deerflow.persistence.thread_meta import MemoryThreadMetaStore
from deerflow.runtime.runs.store.memory import MemoryRunStore from deerflow.runtime.runs.store.memory import MemoryRunStore
app.state.run_store = MemoryRunStore() app.state.run_store = MemoryRunStore()
app.state.feedback_repo = None app.state.feedback_repo = None
app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store)
from deerflow.persistence.thread_meta import make_thread_store
app.state.thread_store = make_thread_store(sf, app.state.store)
# Run event store (has its own factory with config-driven backend selection) # Run event store (has its own factory with config-driven backend selection)
run_events_config = getattr(config, "run_events", None) run_events_config = getattr(config, "run_events", None)
@@ -110,7 +110,12 @@ def get_store(request: Request):
return getattr(request.app.state, "store", None) return getattr(request.app.state, "store", None)
get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store") def get_thread_store(request: Request) -> ThreadMetaStore:
"""Return the thread metadata store (SQL or memory-backed)."""
val = getattr(request.app.state, "thread_store", None)
if val is None:
raise HTTPException(status_code=503, detail="Thread metadata store not available")
return val
def get_run_context(request: Request) -> RunContext: def get_run_context(request: Request) -> RunContext:
@@ -128,7 +133,7 @@ def get_run_context(request: Request) -> RunContext:
store=get_store(request), store=get_store(request),
event_store=get_run_event_store(request), event_store=get_run_event_store(request),
run_events_config=getattr(get_app_config(), "run_events", None), run_events_config=getattr(get_app_config(), "run_events", None),
thread_meta_repo=get_thread_meta_repo(request), thread_store=get_thread_store(request),
) )
+5 -5
View File
@@ -93,14 +93,14 @@ async def authenticate(request):
@auth.on @auth.on
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict): async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
"""Inject owner_id metadata on writes; filter by owner_id on reads. """Inject user_id metadata on writes; filter by user_id on reads.
Gateway stores thread ownership as ``metadata.owner_id``. Gateway stores thread ownership as ``metadata.user_id``.
This handler ensures LangGraph Server enforces the same isolation. This handler ensures LangGraph Server enforces the same isolation.
""" """
# On create/update: stamp owner_id into metadata # On create/update: stamp user_id into metadata
metadata = value.setdefault("metadata", {}) metadata = value.setdefault("metadata", {})
metadata["owner_id"] = ctx.user.identity metadata["user_id"] = ctx.user.identity
# Return filter dict — LangGraph applies it to search/read/delete # Return filter dict — LangGraph applies it to search/read/delete
return {"owner_id": ctx.user.identity} return {"user_id": ctx.user.identity}
+2 -2
View File
@@ -34,7 +34,7 @@ class FeedbackResponse(BaseModel):
feedback_id: str feedback_id: str
run_id: str run_id: str
thread_id: str thread_id: str
owner_id: str | None = None user_id: str | None = None
message_id: str | None = None message_id: str | None = None
rating: int rating: int
comment: str | None = None comment: str | None = None
@@ -80,7 +80,7 @@ async def create_feedback(
run_id=run_id, run_id=run_id,
thread_id=thread_id, thread_id=thread_id,
rating=body.rating, rating=body.rating,
owner_id=user_id, user_id=user_id,
message_id=body.message_id, message_id=body.message_id,
comment=body.comment, comment=body.comment,
) )
+21 -21
View File
@@ -34,7 +34,7 @@ router = APIRouter(prefix="/api/threads", tags=["threads"])
# them. Pydantic ``@field_validator("metadata")`` strips them on every # them. Pydantic ``@field_validator("metadata")`` strips them on every
# inbound model below so a malicious client cannot reflect a forged # inbound model below so a malicious client cannot reflect a forged
# owner identity through the API surface. Defense-in-depth — the # owner identity through the API surface. Defense-in-depth — the
# row-level invariant is still ``threads_meta.owner_id`` populated from # row-level invariant is still ``threads_meta.user_id`` populated from
# the auth contextvar; this list closes the metadata-blob echo gap. # the auth contextvar; this list closes the metadata-blob echo gap.
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"}) _SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
@@ -194,7 +194,7 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
and removes the thread_meta row from the configured ThreadMetaStore and removes the thread_meta row from the configured ThreadMetaStore
(sqlite or memory). (sqlite or memory).
""" """
from app.gateway.deps import get_thread_meta_repo from app.gateway.deps import get_thread_store
# Clean local filesystem # Clean local filesystem
response = _delete_thread_data(thread_id) response = _delete_thread_data(thread_id)
@@ -211,8 +211,8 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
# Remove thread_meta row (best-effort) — required for sqlite backend # Remove thread_meta row (best-effort) — required for sqlite backend
# so the deleted thread no longer appears in /threads/search. # so the deleted thread no longer appears in /threads/search.
try: try:
thread_meta_repo = get_thread_meta_repo(request) thread_store = get_thread_store(request)
await thread_meta_repo.delete(thread_id) await thread_store.delete(thread_id)
except Exception: except Exception:
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id)) logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
@@ -227,17 +227,17 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
and an empty checkpoint (so state endpoints work immediately). and an empty checkpoint (so state endpoints work immediately).
Idempotent: returns the existing record when ``thread_id`` already exists. Idempotent: returns the existing record when ``thread_id`` already exists.
""" """
from app.gateway.deps import get_thread_meta_repo from app.gateway.deps import get_thread_store
checkpointer = get_checkpointer(request) checkpointer = get_checkpointer(request)
thread_meta_repo = get_thread_meta_repo(request) thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4()) thread_id = body.thread_id or str(uuid.uuid4())
now = time.time() now = time.time()
# ``body.metadata`` is already stripped of server-reserved keys by # ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition. # ``ThreadCreateRequest._strip_reserved`` — see the model definition.
# Idempotency: return existing record when already present # Idempotency: return existing record when already present
existing_record = await thread_meta_repo.get(thread_id) existing_record = await thread_store.get(thread_id)
if existing_record is not None: if existing_record is not None:
return ThreadResponse( return ThreadResponse(
thread_id=thread_id, thread_id=thread_id,
@@ -249,7 +249,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
# Write thread_meta so the thread appears in /threads/search immediately # Write thread_meta so the thread appears in /threads/search immediately
try: try:
await thread_meta_repo.create( await thread_store.create(
thread_id, thread_id,
assistant_id=getattr(body, "assistant_id", None), assistant_id=getattr(body, "assistant_id", None),
metadata=body.metadata, metadata=body.metadata,
@@ -293,9 +293,9 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
Delegates to the configured ThreadMetaStore implementation Delegates to the configured ThreadMetaStore implementation
(SQL-backed for sqlite/postgres, Store-backed for memory mode). (SQL-backed for sqlite/postgres, Store-backed for memory mode).
""" """
from app.gateway.deps import get_thread_meta_repo from app.gateway.deps import get_thread_store
repo = get_thread_meta_repo(request) repo = get_thread_store(request)
rows = await repo.search( rows = await repo.search(
metadata=body.metadata or None, metadata=body.metadata or None,
status=body.status, status=body.status,
@@ -320,22 +320,22 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
@require_permission("threads", "write", owner_check=True, require_existing=True) @require_permission("threads", "write", owner_check=True, require_existing=True)
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse: async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
"""Merge metadata into a thread record.""" """Merge metadata into a thread record."""
from app.gateway.deps import get_thread_meta_repo from app.gateway.deps import get_thread_store
thread_meta_repo = get_thread_meta_repo(request) thread_store = get_thread_store(request)
record = await thread_meta_repo.get(thread_id) record = await thread_store.get(thread_id)
if record is None: if record is None:
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found") raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``. # ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
try: try:
await thread_meta_repo.update_metadata(thread_id, body.metadata) await thread_store.update_metadata(thread_id, body.metadata)
except Exception: except Exception:
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id)) logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread") raise HTTPException(status_code=500, detail="Failed to update thread")
# Re-read to get the merged metadata + refreshed updated_at # Re-read to get the merged metadata + refreshed updated_at
record = await thread_meta_repo.get(thread_id) or record record = await thread_store.get(thread_id) or record
return ThreadResponse( return ThreadResponse(
thread_id=thread_id, thread_id=thread_id,
status=record.get("status", "idle"), status=record.get("status", "idle"),
@@ -354,12 +354,12 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
execution status from the checkpointer. Falls back to the checkpointer execution status from the checkpointer. Falls back to the checkpointer
alone for threads that pre-date ThreadMetaStore adoption (backward compat). alone for threads that pre-date ThreadMetaStore adoption (backward compat).
""" """
from app.gateway.deps import get_thread_meta_repo from app.gateway.deps import get_thread_store
thread_meta_repo = get_thread_meta_repo(request) thread_store = get_thread_store(request)
checkpointer = get_checkpointer(request) checkpointer = get_checkpointer(request)
record: dict | None = await thread_meta_repo.get(thread_id) record: dict | None = await thread_store.get(thread_id)
# Derive accurate status from the checkpointer # Derive accurate status from the checkpointer
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}} config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
@@ -462,10 +462,10 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
ThreadMetaStore abstraction so that ``/threads/search`` reflects the ThreadMetaStore abstraction so that ``/threads/search`` reflects the
change immediately in both sqlite and memory backends. change immediately in both sqlite and memory backends.
""" """
from app.gateway.deps import get_thread_meta_repo from app.gateway.deps import get_thread_store
checkpointer = get_checkpointer(request) checkpointer = get_checkpointer(request)
thread_meta_repo = get_thread_meta_repo(request) thread_store = get_thread_store(request)
# checkpoint_ns must be present in the config for aput — default to "" # checkpoint_ns must be present in the config for aput — default to ""
# (the root graph namespace). checkpoint_id is optional; omitting it # (the root graph namespace). checkpoint_id is optional; omitting it
@@ -529,7 +529,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
new_title = body.values["title"] new_title = body.values["title"]
if new_title: # Skip empty strings and None if new_title: # Skip empty strings and None
try: try:
await thread_meta_repo.update_display_name(thread_id, new_title) await thread_store.update_display_name(thread_id, new_title)
except Exception: except Exception:
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
+4 -4
View File
@@ -229,15 +229,15 @@ async def start_run(
# even for threads that were never explicitly created via POST /threads # even for threads that were never explicitly created via POST /threads
# (e.g. stateless runs). # (e.g. stateless runs).
try: try:
existing = await run_ctx.thread_meta_repo.get(thread_id) existing = await run_ctx.thread_store.get(thread_id)
if existing is None: if existing is None:
await run_ctx.thread_meta_repo.create( await run_ctx.thread_store.create(
thread_id, thread_id,
assistant_id=body.assistant_id, assistant_id=body.assistant_id,
metadata=body.metadata, metadata=body.metadata,
) )
else: else:
await run_ctx.thread_meta_repo.update_status(thread_id, "running") await run_ctx.thread_store.update_status(thread_id, "running")
except Exception: except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
@@ -285,7 +285,7 @@ async def start_run(
record.task = task record.task = task
# Title sync is handled by worker.py's finally block which reads the # Title sync is handled by worker.py's finally block which reads the
# title from the checkpoint and calls thread_meta_repo.update_display_name # title from the checkpoint and calls thread_store.update_display_name
# after the run completes. # after the run completes.
return record return record
@@ -16,7 +16,7 @@ class FeedbackRow(Base):
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True) feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
owner_id: Mapped[str | None] = mapped_column(String(64), index=True) user_id: Mapped[str | None] = mapped_column(String(64), index=True)
message_id: Mapped[str | None] = mapped_column(String(64)) message_id: Mapped[str | None] = mapped_column(String(64))
# message_id is an optional RunEventStore event identifier — # message_id is an optional RunEventStore event identifier —
# allows feedback to target a specific message or the entire run # allows feedback to target a specific message or the entire run
@@ -12,7 +12,7 @@ from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.feedback.model import FeedbackRow from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
class FeedbackRepository: class FeedbackRepository:
@@ -33,19 +33,19 @@ class FeedbackRepository:
run_id: str, run_id: str,
thread_id: str, thread_id: str,
rating: int, rating: int,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
message_id: str | None = None, message_id: str | None = None,
comment: str | None = None, comment: str | None = None,
) -> dict: ) -> dict:
"""Create a feedback record. rating must be +1 or -1.""" """Create a feedback record. rating must be +1 or -1."""
if rating not in (1, -1): if rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {rating}") raise ValueError(f"rating must be +1 or -1, got {rating}")
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create") resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create")
row = FeedbackRow( row = FeedbackRow(
feedback_id=str(uuid.uuid4()), feedback_id=str(uuid.uuid4()),
run_id=run_id, run_id=run_id,
thread_id=thread_id, thread_id=thread_id,
owner_id=resolved_owner_id, user_id=resolved_user_id,
message_id=message_id, message_id=message_id,
rating=rating, rating=rating,
comment=comment, comment=comment,
@@ -61,14 +61,14 @@ class FeedbackRepository:
self, self,
feedback_id: str, feedback_id: str,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
) -> dict | None: ) -> dict | None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get") resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get")
async with self._sf() as session: async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id) row = await session.get(FeedbackRow, feedback_id)
if row is None: if row is None:
return None return None
if resolved_owner_id is not None and row.owner_id != resolved_owner_id: if resolved_user_id is not None and row.user_id != resolved_user_id:
return None return None
return self._row_to_dict(row) return self._row_to_dict(row)
@@ -78,12 +78,12 @@ class FeedbackRepository:
run_id: str, run_id: str,
*, *,
limit: int = 100, limit: int = 100,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]: ) -> list[dict]:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run") resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id) stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
if resolved_owner_id is not None: if resolved_user_id is not None:
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id) stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit) stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session: async with self._sf() as session:
result = await session.execute(stmt) result = await session.execute(stmt)
@@ -94,12 +94,12 @@ class FeedbackRepository:
thread_id: str, thread_id: str,
*, *,
limit: int = 100, limit: int = 100,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]: ) -> list[dict]:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread") resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id) stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
if resolved_owner_id is not None: if resolved_user_id is not None:
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id) stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit) stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session: async with self._sf() as session:
result = await session.execute(stmt) result = await session.execute(stmt)
@@ -109,14 +109,14 @@ class FeedbackRepository:
self, self,
feedback_id: str, feedback_id: str,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
) -> bool: ) -> bool:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete") resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete")
async with self._sf() as session: async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id) row = await session.get(FeedbackRow, feedback_id)
if row is None: if row is None:
return False return False
if resolved_owner_id is not None and row.owner_id != resolved_owner_id: if resolved_user_id is not None and row.user_id != resolved_user_id:
return False return False
await session.delete(row) await session.delete(row)
await session.commit() await session.commit()
@@ -19,7 +19,7 @@ class RunEventRow(Base):
# Owner of the conversation this event belongs to. Nullable for data # Owner of the conversation this event belongs to. Nullable for data
# created before auth was introduced; populated by auth middleware on # created before auth was introduced; populated by auth middleware on
# new writes and by the boot-time orphan migration on existing rows. # new writes and by the boot-time orphan migration on existing rows.
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True) user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
event_type: Mapped[str] = mapped_column(String(32), nullable=False) event_type: Mapped[str] = mapped_column(String(32), nullable=False)
category: Mapped[str] = mapped_column(String(16), nullable=False) category: Mapped[str] = mapped_column(String(16), nullable=False)
# "message" | "trace" | "lifecycle" # "message" | "trace" | "lifecycle"
@@ -16,7 +16,7 @@ class RunRow(Base):
run_id: Mapped[str] = mapped_column(String(64), primary_key=True) run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True) thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
assistant_id: Mapped[str | None] = mapped_column(String(128)) assistant_id: Mapped[str | None] = mapped_column(String(128))
owner_id: Mapped[str | None] = mapped_column(String(64), index=True) user_id: Mapped[str | None] = mapped_column(String(64), index=True)
status: Mapped[str] = mapped_column(String(20), default="pending") status: Mapped[str] = mapped_column(String(20), default="pending")
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted" # "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
@@ -16,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.run.model import RunRow from deerflow.persistence.run.model import RunRow
from deerflow.runtime.runs.store.base import RunStore from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
class RunRepository(RunStore): class RunRepository(RunStore):
@@ -69,7 +69,7 @@ class RunRepository(RunStore):
*, *,
thread_id, thread_id,
assistant_id=None, assistant_id=None,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
status="pending", status="pending",
multitask_strategy="reject", multitask_strategy="reject",
metadata=None, metadata=None,
@@ -78,13 +78,13 @@ class RunRepository(RunStore):
created_at=None, created_at=None,
follow_up_to_run_id=None, follow_up_to_run_id=None,
): ):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put") resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
now = datetime.now(UTC) now = datetime.now(UTC)
row = RunRow( row = RunRow(
run_id=run_id, run_id=run_id,
thread_id=thread_id, thread_id=thread_id,
assistant_id=assistant_id, assistant_id=assistant_id,
owner_id=resolved_owner_id, user_id=resolved_user_id,
status=status, status=status,
multitask_strategy=multitask_strategy, multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {}, metadata_json=self._safe_json(metadata) or {},
@@ -102,14 +102,14 @@ class RunRepository(RunStore):
self, self,
run_id, run_id,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
): ):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get") resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get")
async with self._sf() as session: async with self._sf() as session:
row = await session.get(RunRow, run_id) row = await session.get(RunRow, run_id)
if row is None: if row is None:
return None return None
if resolved_owner_id is not None and row.owner_id != resolved_owner_id: if resolved_user_id is not None and row.user_id != resolved_user_id:
return None return None
return self._row_to_dict(row) return self._row_to_dict(row)
@@ -117,13 +117,13 @@ class RunRepository(RunStore):
self, self,
thread_id, thread_id,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
limit=100, limit=100,
): ):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread") resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread")
stmt = select(RunRow).where(RunRow.thread_id == thread_id) stmt = select(RunRow).where(RunRow.thread_id == thread_id)
if resolved_owner_id is not None: if resolved_user_id is not None:
stmt = stmt.where(RunRow.owner_id == resolved_owner_id) stmt = stmt.where(RunRow.user_id == resolved_user_id)
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit) stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
async with self._sf() as session: async with self._sf() as session:
result = await session.execute(stmt) result = await session.execute(stmt)
@@ -141,14 +141,14 @@ class RunRepository(RunStore):
self, self,
run_id, run_id,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
): ):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete") resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete")
async with self._sf() as session: async with self._sf() as session:
row = await session.get(RunRow, run_id) row = await session.get(RunRow, run_id)
if row is None: if row is None:
return return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id: if resolved_user_id is not None and row.user_id != resolved_user_id:
return return
await session.delete(row) await session.delete(row)
await session.commit() await session.commit()
@@ -15,7 +15,7 @@ class ThreadMetaRow(Base):
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True) thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True) assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
owner_id: Mapped[str | None] = mapped_column(String(64), index=True) user_id: Mapped[str | None] = mapped_column(String(64), index=True)
display_name: Mapped[str | None] = mapped_column(String(256)) display_name: Mapped[str | None] = mapped_column(String(256))
status: Mapped[str] = mapped_column(String(20), default="idle") status: Mapped[str] = mapped_column(String(20), default="idle")
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict) metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.thread_meta.base import ThreadMetaStore from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
class ThreadMetaRepository(ThreadMetaStore): class ThreadMetaRepository(ThreadMetaStore):
@@ -32,18 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
thread_id: str, thread_id: str,
*, *,
assistant_id: str | None = None, assistant_id: str | None = None,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
display_name: str | None = None, display_name: str | None = None,
metadata: dict | None = None, metadata: dict | None = None,
) -> dict: ) -> dict:
# Auto-resolve owner_id from contextvar when AUTO; explicit None # Auto-resolve user_id from contextvar when AUTO; explicit None
# creates an orphan row (used by migration scripts). # creates an orphan row (used by migration scripts).
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create") resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create")
now = datetime.now(UTC) now = datetime.now(UTC)
row = ThreadMetaRow( row = ThreadMetaRow(
thread_id=thread_id, thread_id=thread_id,
assistant_id=assistant_id, assistant_id=assistant_id,
owner_id=resolved_owner_id, user_id=resolved_user_id,
display_name=display_name, display_name=display_name,
metadata_json=metadata or {}, metadata_json=metadata or {},
created_at=now, created_at=now,
@@ -59,40 +59,34 @@ class ThreadMetaRepository(ThreadMetaStore):
self, self,
thread_id: str, thread_id: str,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
) -> dict | None: ) -> dict | None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get") resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get")
async with self._sf() as session: async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id) row = await session.get(ThreadMetaRow, thread_id)
if row is None: if row is None:
return None return None
# Enforce owner filter unless explicitly bypassed (owner_id=None). # Enforce owner filter unless explicitly bypassed (user_id=None).
if resolved_owner_id is not None and row.owner_id != resolved_owner_id: if resolved_user_id is not None and row.user_id != resolved_user_id:
return None return None
return self._row_to_dict(row) return self._row_to_dict(row)
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]: async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset) """Check if ``user_id`` has access to ``thread_id``.
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
"""Check if ``owner_id`` has access to ``thread_id``.
Two modes — one row, two distinct semantics depending on what Two modes — one row, two distinct semantics depending on what
the caller is about to do: the caller is about to do:
- ``require_existing=False`` (default, permissive): - ``require_existing=False`` (default, permissive):
Returns True for: row missing (untracked legacy thread), Returns True for: row missing (untracked legacy thread),
``row.owner_id`` is None (shared / pre-auth data), ``row.user_id`` is None (shared / pre-auth data),
or ``row.owner_id == owner_id``. Use for **read-style** or ``row.user_id == user_id``. Use for **read-style**
decorators where treating an untracked thread as accessible decorators where treating an untracked thread as accessible
preserves backward-compat. preserves backward-compat.
- ``require_existing=True`` (strict): - ``require_existing=True`` (strict):
Returns True **only** when the row exists AND Returns True **only** when the row exists AND
(``row.owner_id == owner_id`` OR ``row.owner_id is None``). (``row.user_id == user_id`` OR ``row.user_id is None``).
Use for **destructive / mutating** decorators (DELETE, PATCH, Use for **destructive / mutating** decorators (DELETE, PATCH,
state-update) so a thread that has *already been deleted* state-update) so a thread that has *already been deleted*
cannot be re-targeted by any caller — closing the cannot be re-targeted by any caller — closing the
@@ -103,9 +97,9 @@ class ThreadMetaRepository(ThreadMetaStore):
row = await session.get(ThreadMetaRow, thread_id) row = await session.get(ThreadMetaRow, thread_id)
if row is None: if row is None:
return not require_existing return not require_existing
if row.owner_id is None: if row.user_id is None:
return True return True
return row.owner_id == owner_id return row.user_id == user_id
async def search( async def search(
self, self,
@@ -114,17 +108,17 @@ class ThreadMetaRepository(ThreadMetaStore):
status: str | None = None, status: str | None = None,
limit: int = 100, limit: int = 100,
offset: int = 0, offset: int = 0,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]: ) -> list[dict]:
"""Search threads with optional metadata and status filters. """Search threads with optional metadata and status filters.
Owner filter is enforced by default: caller must be in a user Owner filter is enforced by default: caller must be in a user
context. Pass ``owner_id=None`` to bypass (migration/CLI). context. Pass ``user_id=None`` to bypass (migration/CLI).
""" """
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search") resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc()) stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
if resolved_owner_id is not None: if resolved_user_id is not None:
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id) stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
if status: if status:
stmt = stmt.where(ThreadMetaRow.status == status) stmt = stmt.where(ThreadMetaRow.status == status)
@@ -144,24 +138,24 @@ class ThreadMetaRepository(ThreadMetaStore):
result = await session.execute(stmt) result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()] return [self._row_to_dict(r) for r in result.scalars()]
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool: async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
"""Return True if the row exists and is owned (or filter bypassed).""" """Return True if the row exists and is owned (or filter bypassed)."""
if resolved_owner_id is None: if resolved_user_id is None:
return True # explicit bypass return True # explicit bypass
row = await session.get(ThreadMetaRow, thread_id) row = await session.get(ThreadMetaRow, thread_id)
return row is not None and row.owner_id == resolved_owner_id return row is not None and row.user_id == resolved_user_id
async def update_display_name( async def update_display_name(
self, self,
thread_id: str, thread_id: str,
display_name: str, display_name: str,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
) -> None: ) -> None:
"""Update the display_name (title) for a thread.""" """Update the display_name (title) for a thread."""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name") resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name")
async with self._sf() as session: async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_owner_id): if not await self._check_ownership(session, thread_id, resolved_user_id):
return return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC))) await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
await session.commit() await session.commit()
@@ -171,11 +165,11 @@ class ThreadMetaRepository(ThreadMetaStore):
thread_id: str, thread_id: str,
status: str, status: str,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
) -> None: ) -> None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status") resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status")
async with self._sf() as session: async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_owner_id): if not await self._check_ownership(session, thread_id, resolved_user_id):
return return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC))) await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
await session.commit() await session.commit()
@@ -185,20 +179,20 @@ class ThreadMetaRepository(ThreadMetaStore):
thread_id: str, thread_id: str,
metadata: dict, metadata: dict,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
) -> None: ) -> None:
"""Merge ``metadata`` into ``metadata_json``. """Merge ``metadata`` into ``metadata_json``.
Read-modify-write inside a single session/transaction so concurrent Read-modify-write inside a single session/transaction so concurrent
callers see consistent state. No-op if the row does not exist or callers see consistent state. No-op if the row does not exist or
the owner_id check fails. the user_id check fails.
""" """
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata") resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata")
async with self._sf() as session: async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id) row = await session.get(ThreadMetaRow, thread_id)
if row is None: if row is None:
return return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id: if resolved_user_id is not None and row.user_id != resolved_user_id:
return return
merged = dict(row.metadata_json or {}) merged = dict(row.metadata_json or {})
merged.update(metadata) merged.update(metadata)
@@ -210,14 +204,14 @@ class ThreadMetaRepository(ThreadMetaStore):
self, self,
thread_id: str, thread_id: str,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
) -> None: ) -> None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete") resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete")
async with self._sf() as session: async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id) row = await session.get(ThreadMetaRow, thread_id)
if row is None: if row is None:
return return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id: if resolved_user_id is not None and row.user_id != resolved_user_id:
return return
await session.delete(row) await session.delete(row)
await session.commit() await session.commit()
@@ -24,12 +24,12 @@ from collections.abc import AsyncIterator
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.runtime.checkpointer.provider import ( from deerflow.runtime.checkpointer.provider import (
POSTGRES_CONN_REQUIRED, POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL, POSTGRES_INSTALL,
SQLITE_INSTALL, SQLITE_INSTALL,
) )
from deerflow.config.app_config import get_app_config
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow from deerflow.persistence.models.run_event import RunEventRow
from deerflow.runtime.events.store.base import RunEventStore from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -55,8 +55,8 @@ class DbRunEventStore(RunEventStore):
return content, metadata or {} return content, metadata or {}
@staticmethod @staticmethod
def _owner_from_context() -> str | None: def _user_id_from_context() -> str | None:
"""Soft read of owner_id from contextvar for write paths. """Soft read of user_id from contextvar for write paths.
Returns ``None`` (no filter / no stamp) if contextvar is unset, Returns ``None`` (no filter / no stamp) if contextvar is unset,
which is the expected case for background worker writes. HTTP which is the expected case for background worker writes. HTTP
@@ -81,7 +81,7 @@ class DbRunEventStore(RunEventStore):
metadata = {**(metadata or {}), "content_is_dict": True} metadata = {**(metadata or {}), "content_is_dict": True}
else: else:
db_content = content db_content = content
owner_id = self._owner_from_context() user_id = self._user_id_from_context()
async with self._sf() as session: async with self._sf() as session:
async with session.begin(): async with session.begin():
# Use FOR UPDATE to serialize seq assignment within a thread. # Use FOR UPDATE to serialize seq assignment within a thread.
@@ -92,7 +92,7 @@ class DbRunEventStore(RunEventStore):
row = RunEventRow( row = RunEventRow(
thread_id=thread_id, thread_id=thread_id,
run_id=run_id, run_id=run_id,
owner_id=owner_id, user_id=user_id,
event_type=event_type, event_type=event_type,
category=category, category=category,
content=db_content, content=db_content,
@@ -106,7 +106,7 @@ class DbRunEventStore(RunEventStore):
async def put_batch(self, events): async def put_batch(self, events):
if not events: if not events:
return [] return []
owner_id = self._owner_from_context() user_id = self._user_id_from_context()
async with self._sf() as session: async with self._sf() as session:
async with session.begin(): async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread). # Get max seq for the thread (assume all events in batch belong to same thread).
@@ -130,7 +130,7 @@ class DbRunEventStore(RunEventStore):
row = RunEventRow( row = RunEventRow(
thread_id=e["thread_id"], thread_id=e["thread_id"],
run_id=e["run_id"], run_id=e["run_id"],
owner_id=e.get("owner_id", owner_id), user_id=e.get("user_id", user_id),
event_type=e["event_type"], event_type=e["event_type"],
category=category, category=category,
content=db_content, content=db_content,
@@ -149,12 +149,12 @@ class DbRunEventStore(RunEventStore):
limit=50, limit=50,
before_seq=None, before_seq=None,
after_seq=None, after_seq=None,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
): ):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages") resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message") stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_owner_id is not None: if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
if before_seq is not None: if before_seq is not None:
stmt = stmt.where(RunEventRow.seq < before_seq) stmt = stmt.where(RunEventRow.seq < before_seq)
if after_seq is not None: if after_seq is not None:
@@ -181,12 +181,12 @@ class DbRunEventStore(RunEventStore):
*, *,
event_types=None, event_types=None,
limit=500, limit=500,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
): ):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events") resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id) stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
if resolved_owner_id is not None: if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
if event_types: if event_types:
stmt = stmt.where(RunEventRow.event_type.in_(event_types)) stmt = stmt.where(RunEventRow.event_type.in_(event_types))
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit) stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
@@ -199,12 +199,12 @@ class DbRunEventStore(RunEventStore):
thread_id, thread_id,
run_id, run_id,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
): ):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run") resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message") stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
if resolved_owner_id is not None: if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
stmt = stmt.order_by(RunEventRow.seq.asc()) stmt = stmt.order_by(RunEventRow.seq.asc())
async with self._sf() as session: async with self._sf() as session:
result = await session.execute(stmt) result = await session.execute(stmt)
@@ -214,12 +214,12 @@ class DbRunEventStore(RunEventStore):
self, self,
thread_id, thread_id,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
): ):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages") resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message") stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_owner_id is not None: if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id) stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
async with self._sf() as session: async with self._sf() as session:
return await session.scalar(stmt) or 0 return await session.scalar(stmt) or 0
@@ -227,13 +227,13 @@ class DbRunEventStore(RunEventStore):
self, self,
thread_id, thread_id,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
): ):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread") resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
async with self._sf() as session: async with self._sf() as session:
count_conditions = [RunEventRow.thread_id == thread_id] count_conditions = [RunEventRow.thread_id == thread_id]
if resolved_owner_id is not None: if resolved_user_id is not None:
count_conditions.append(RunEventRow.owner_id == resolved_owner_id) count_conditions.append(RunEventRow.user_id == resolved_user_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions) count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0 count = await session.scalar(count_stmt) or 0
if count > 0: if count > 0:
@@ -246,13 +246,13 @@ class DbRunEventStore(RunEventStore):
thread_id, thread_id,
run_id, run_id,
*, *,
owner_id: str | None | _AutoSentinel = AUTO, user_id: str | None | _AutoSentinel = AUTO,
): ):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run") resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
async with self._sf() as session: async with self._sf() as session:
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id] count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
if resolved_owner_id is not None: if resolved_user_id is not None:
count_conditions.append(RunEventRow.owner_id == resolved_owner_id) count_conditions.append(RunEventRow.user_id == resolved_user_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions) count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0 count = await session.scalar(count_stmt) or 0
if count > 0: if count > 0:
@@ -4,8 +4,8 @@ RunManager depends on this interface. Implementations:
- MemoryRunStore: in-memory dict (development, tests) - MemoryRunStore: in-memory dict (development, tests)
- Future: RunRepository backed by SQLAlchemy ORM - Future: RunRepository backed by SQLAlchemy ORM
All methods accept an optional owner_id for user isolation. All methods accept an optional user_id for user isolation.
When owner_id is None, no user filtering is applied (single-user mode). When user_id is None, no user filtering is applied (single-user mode).
""" """
from __future__ import annotations from __future__ import annotations
@@ -22,7 +22,7 @@ class RunStore(abc.ABC):
*, *,
thread_id: str, thread_id: str,
assistant_id: str | None = None, assistant_id: str | None = None,
owner_id: str | None = None, user_id: str | None = None,
status: str = "pending", status: str = "pending",
multitask_strategy: str = "reject", multitask_strategy: str = "reject",
metadata: dict[str, Any] | None = None, metadata: dict[str, Any] | None = None,
@@ -42,7 +42,7 @@ class RunStore(abc.ABC):
self, self,
thread_id: str, thread_id: str,
*, *,
owner_id: str | None = None, user_id: str | None = None,
limit: int = 100, limit: int = 100,
) -> list[dict[str, Any]]: ) -> list[dict[str, Any]]:
pass pass
@@ -21,7 +21,7 @@ class MemoryRunStore(RunStore):
*, *,
thread_id, thread_id,
assistant_id=None, assistant_id=None,
owner_id=None, user_id=None,
status="pending", status="pending",
multitask_strategy="reject", multitask_strategy="reject",
metadata=None, metadata=None,
@@ -35,7 +35,7 @@ class MemoryRunStore(RunStore):
"run_id": run_id, "run_id": run_id,
"thread_id": thread_id, "thread_id": thread_id,
"assistant_id": assistant_id, "assistant_id": assistant_id,
"owner_id": owner_id, "user_id": user_id,
"status": status, "status": status,
"multitask_strategy": multitask_strategy, "multitask_strategy": multitask_strategy,
"metadata": metadata or {}, "metadata": metadata or {},
@@ -49,8 +49,8 @@ class MemoryRunStore(RunStore):
async def get(self, run_id): async def get(self, run_id):
return self._runs.get(run_id) return self._runs.get(run_id)
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100): async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (owner_id is None or r.get("owner_id") == owner_id)] results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)]
results.sort(key=lambda r: r["created_at"], reverse=True) results.sort(key=lambda r: r["created_at"], reverse=True)
return results[:limit] return results[:limit]
@@ -50,7 +50,7 @@ class RunContext:
store: Any | None = field(default=None) store: Any | None = field(default=None)
event_store: Any | None = field(default=None) event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None) run_events_config: Any | None = field(default=None)
thread_meta_repo: Any | None = field(default=None) thread_store: Any | None = field(default=None)
follow_up_to_run_id: str | None = field(default=None) follow_up_to_run_id: str | None = field(default=None)
@@ -75,7 +75,7 @@ async def run_agent(
store = ctx.store store = ctx.store
event_store = ctx.event_store event_store = ctx.event_store
run_events_config = ctx.run_events_config run_events_config = ctx.run_events_config
thread_meta_repo = ctx.thread_meta_repo thread_store = ctx.thread_store
follow_up_to_run_id = ctx.follow_up_to_run_id follow_up_to_run_id = ctx.follow_up_to_run_id
run_id = record.run_id run_id = record.run_id
@@ -376,14 +376,14 @@ async def run_agent(
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {} ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
title = ckpt.get("channel_values", {}).get("title") title = ckpt.get("channel_values", {}).get("title")
if title: if title:
await thread_meta_repo.update_display_name(thread_id, title) await thread_store.update_display_name(thread_id, title)
except Exception: except Exception:
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id) logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
# Update threads_meta status based on run outcome # Update threads_meta status based on run outcome
try: try:
final_status = "idle" if record.status == RunStatus.success else record.status.value final_status = "idle" if record.status == RunStatus.success else record.status.value
await thread_meta_repo.update_status(thread_id, final_status) await thread_store.update_status(thread_id, final_status)
except Exception: except Exception:
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id) logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
@@ -1,11 +1,11 @@
"""Request-scoped user context for owner-based authorization. """Request-scoped user context for user-based authorization.
This module holds a :class:`~contextvars.ContextVar` that the gateway's This module holds a :class:`~contextvars.ContextVar` that the gateway's
auth middleware sets after a successful authentication. Repository auth middleware sets after a successful authentication. Repository
methods read the contextvar via a sentinel default parameter, letting methods read the contextvar via a sentinel default parameter, letting
routers stay free of ``owner_id`` boilerplate. routers stay free of ``user_id`` boilerplate.
Three-state semantics for the repository ``owner_id`` parameter (the Three-state semantics for the repository ``user_id`` parameter (the
consumer side of this module lives in ``deerflow.persistence.*``): consumer side of this module lives in ``deerflow.persistence.*``):
- ``_AUTO`` (module-private sentinel, default): read from contextvar; - ``_AUTO`` (module-private sentinel, default): read from contextvar;
@@ -91,16 +91,16 @@ def require_current_user() -> CurrentUser:
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Sentinel-based owner_id resolution # Sentinel-based user_id resolution
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# #
# Repository methods accept an ``owner_id`` keyword-only argument that # Repository methods accept a ``user_id`` keyword-only argument that
# defaults to ``AUTO``. The three possible values drive distinct # defaults to ``AUTO``. The three possible values drive distinct
# behaviours; see the docstring on :func:`resolve_owner_id`. # behaviours; see the docstring on :func:`resolve_user_id`.
class _AutoSentinel: class _AutoSentinel:
"""Singleton marker meaning 'resolve owner_id from contextvar'.""" """Singleton marker meaning 'resolve user_id from contextvar'."""
_instance: _AutoSentinel | None = None _instance: _AutoSentinel | None = None
@@ -116,12 +116,12 @@ class _AutoSentinel:
AUTO: Final[_AutoSentinel] = _AutoSentinel() AUTO: Final[_AutoSentinel] = _AutoSentinel()
def resolve_owner_id( def resolve_user_id(
value: str | None | _AutoSentinel, value: str | None | _AutoSentinel,
*, *,
method_name: str = "repository method", method_name: str = "repository method",
) -> str | None: ) -> str | None:
"""Resolve the owner_id parameter passed to a repository method. """Resolve the user_id parameter passed to a repository method.
Three-state semantics: Three-state semantics:
@@ -131,16 +131,16 @@ def resolve_owner_id(
- Explicit ``str``: use the provided id verbatim, overriding any - Explicit ``str``: use the provided id verbatim, overriding any
contextvar value. Useful for tests and admin-override flows. contextvar value. Useful for tests and admin-override flows.
- Explicit ``None``: no filter — the repository should skip the - Explicit ``None``: no filter — the repository should skip the
owner_id WHERE clause entirely. Reserved for migration scripts user_id WHERE clause entirely. Reserved for migration scripts
and CLI tools that intentionally bypass isolation. and CLI tools that intentionally bypass isolation.
""" """
if isinstance(value, _AutoSentinel): if isinstance(value, _AutoSentinel):
user = _current_user.get() user = _current_user.get()
if user is None: if user is None:
raise RuntimeError(f"{method_name} called with owner_id=AUTO but no user context is set; pass an explicit owner_id, set the contextvar via auth middleware, or opt out with owner_id=None for migration/CLI paths.") raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.")
# Coerce to ``str`` at the boundary: ``User.id`` is typed as # Coerce to ``str`` at the boundary: ``User.id`` is typed as
# ``UUID`` for the API surface, but the persistence layer # ``UUID`` for the API surface, but the persistence layer
# stores ``owner_id`` as ``String(64)`` and aiosqlite cannot # stores ``user_id`` as ``String(64)`` and aiosqlite cannot
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is # bind a raw UUID object to a VARCHAR column ("type 'UUID' is
# not supported"). Honour the documented return type here # not supported"). Honour the documented return type here
# rather than ripple a type change through every caller. # rather than ripple a type change through every caller.
+7 -7
View File
@@ -3,16 +3,16 @@
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie) The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
ahead of every router, plus ``@require_permission(owner_check=True)`` ahead of every router, plus ``@require_permission(owner_check=True)``
decorators that read ``request.state.auth`` and call decorators that read ``request.state.auth`` and call
``thread_meta_repo.check_access``. Router-level unit tests construct ``thread_store.check_access``. Router-level unit tests construct
**bare** FastAPI apps that include only one router — they have neither **bare** FastAPI apps that include only one router — they have neither
the auth middleware nor a real thread_meta_repo, so the decorators raise the auth middleware nor a real thread_store, so the decorators raise
401 (TestClient path) or ValueError (direct-call path). 401 (TestClient path) or ValueError (direct-call path).
This module provides two surfaces: This module provides two surfaces:
1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny 1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every ``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
request, plus a permissive ``thread_meta_repo`` mock on request, plus a permissive ``thread_store`` mock on
``app.state``. Use from TestClient-based router tests. ``app.state``. Use from TestClient-based router tests.
2. :func:`call_unwrapped` — invokes the underlying function bypassing 2. :func:`call_unwrapped` — invokes the underlying function bypassing
@@ -86,20 +86,20 @@ def make_authed_test_app(
user_factory: Callable[[], User] | None = None, user_factory: Callable[[], User] | None = None,
owner_check_passes: bool = True, owner_check_passes: bool = True,
) -> FastAPI: ) -> FastAPI:
"""Build a FastAPI test app with stub auth + permissive thread_meta_repo. """Build a FastAPI test app with stub auth + permissive thread_store.
Args: Args:
user_factory: Override the default test user. Must return a fully user_factory: Override the default test user. Must return a fully
populated :class:`User`. Useful for cross-user isolation tests populated :class:`User`. Useful for cross-user isolation tests
that need a stable id across requests. that need a stable id across requests.
owner_check_passes: When True (default), ``thread_meta_repo.check_access`` owner_check_passes: When True (default), ``thread_store.check_access``
returns True for every call so ``@require_permission(owner_check=True)`` returns True for every call so ``@require_permission(owner_check=True)``
never blocks the route under test. Pass False to verify that never blocks the route under test. Pass False to verify that
permission failures surface correctly. permission failures surface correctly.
Returns: Returns:
A ``FastAPI`` app with the stub middleware installed and A ``FastAPI`` app with the stub middleware installed and
``app.state.thread_meta_repo`` set to a permissive mock. The ``app.state.thread_store`` set to a permissive mock. The
caller is still responsible for ``app.include_router(...)``. caller is still responsible for ``app.include_router(...)``.
""" """
factory = user_factory or _make_stub_user factory = user_factory or _make_stub_user
@@ -108,7 +108,7 @@ def make_authed_test_app(
repo = MagicMock() repo = MagicMock()
repo.check_access = AsyncMock(return_value=owner_check_passes) repo.check_access = AsyncMock(return_value=owner_check_passes)
app.state.thread_meta_repo = repo app.state.thread_store = repo
return app return app
+1 -1
View File
@@ -60,7 +60,7 @@ def provisioner_module():
# Auto-set user context for every test unless marked no_auto_user # Auto-set user context for every test unless marked no_auto_user
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# #
# Repository methods read ``owner_id`` from a contextvar by default # Repository methods read ``user_id`` from a contextvar by default
# (see ``deerflow.runtime.user_context``). Without this fixture, every # (see ``deerflow.runtime.user_context``). Without this fixture, every
# pre-existing persistence test would raise RuntimeError because the # pre-existing persistence test would raise RuntimeError because the
# contextvar is unset. The fixture sets a default test user on every # contextvar is unset. The fixture sets a default test user on every
+1 -1
View File
@@ -6,13 +6,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import deerflow.config.app_config as app_config_module import deerflow.config.app_config as app_config_module
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.config.checkpointer_config import ( from deerflow.config.checkpointer_config import (
CheckpointerConfig, CheckpointerConfig,
get_checkpointer_config, get_checkpointer_config,
load_checkpointer_config_from_dict, load_checkpointer_config_from_dict,
set_checkpointer_config, set_checkpointer_config,
) )
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
+6 -6
View File
@@ -199,12 +199,12 @@ def test_migration_failure_is_non_fatal():
# ── Section 5.1-5.6 upgrade path: orphan thread migration ──────────────── # ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows(): def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
"""First boot finds Store-only legacy threads → stamps admin's id. """First boot finds Store-only legacy threads → stamps admin's id.
Validates the **TC-UPG-02 upgrade story**: an operator running main Validates the **TC-UPG-02 upgrade story**: an operator running main
(no auth) accumulates threads in the LangGraph Store namespace (no auth) accumulates threads in the LangGraph Store namespace
``("threads",)`` with no ``metadata.owner_id``. After upgrading to ``("threads",)`` with no ``metadata.user_id``. After upgrading to
feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
rewrite each unowned item with the freshly created admin's id. rewrite each unowned item with the freshly created admin's id.
""" """
@@ -215,7 +215,7 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}), SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}), SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
SimpleNamespace(key="t3", value={"metadata": {}}), SimpleNamespace(key="t3", value={"metadata": {}}),
SimpleNamespace(key="t4", value={"metadata": {"owner_id": "someone-else", "title": "preserved"}}), SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}),
] ]
store = AsyncMock() store = AsyncMock()
# asearch returns the entire batch on first call, then an empty page # asearch returns the entire batch on first call, then an empty page
@@ -235,11 +235,11 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
assert len(aput_calls) == 3 assert len(aput_calls) == 3
rewritten_keys = {call[1] for call in aput_calls} rewritten_keys = {call[1] for call in aput_calls}
assert rewritten_keys == {"t1", "t2", "t3"} assert rewritten_keys == {"t1", "t2", "t3"}
# Each rewrite carries the new owner_id; titles preserved where present. # Each rewrite carries the new user_id; titles preserved where present.
by_key = {call[1]: call[2] for call in aput_calls} by_key = {call[1]: call[2] for call in aput_calls}
assert by_key["t1"]["metadata"]["owner_id"] == "admin-id-42" assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42"
assert by_key["t1"]["metadata"]["title"] == "old-thread-1" assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
assert by_key["t3"]["metadata"]["owner_id"] == "admin-id-42" assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42"
# The pre-owned item must NOT have been rewritten. # The pre-owned item must NOT have been rewritten.
assert "t4" not in rewritten_keys assert "t4" not in rewritten_keys
+2 -2
View File
@@ -60,8 +60,8 @@ class TestFeedbackRepository:
@pytest.mark.anyio @pytest.mark.anyio
async def test_create_with_owner(self, tmp_path): async def test_create_with_owner(self, tmp_path):
repo = await _make_feedback_repo(tmp_path) repo = await _make_feedback_repo(tmp_path)
record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="user-1") record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
assert record["owner_id"] == "user-1" assert record["user_id"] == "user-1"
await _cleanup() await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
+9 -9
View File
@@ -175,46 +175,46 @@ def _make_ctx(user_id):
def test_filter_injects_user_id(): def test_filter_injects_user_id():
value = {} value = {}
asyncio.run(add_owner_filter(_make_ctx("user-a"), value)) asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
assert value["metadata"]["owner_id"] == "user-a" assert value["metadata"]["user_id"] == "user-a"
def test_filter_preserves_existing_metadata(): def test_filter_preserves_existing_metadata():
value = {"metadata": {"title": "hello"}} value = {"metadata": {"title": "hello"}}
asyncio.run(add_owner_filter(_make_ctx("user-a"), value)) asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
assert value["metadata"]["owner_id"] == "user-a" assert value["metadata"]["user_id"] == "user-a"
assert value["metadata"]["title"] == "hello" assert value["metadata"]["title"] == "hello"
def test_filter_returns_user_id_dict(): def test_filter_returns_user_id_dict():
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {})) result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
assert result == {"owner_id": "user-x"} assert result == {"user_id": "user-x"}
def test_filter_read_write_consistency(): def test_filter_read_write_consistency():
value = {} value = {}
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value)) filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
assert value["metadata"]["owner_id"] == filter_dict["owner_id"] assert value["metadata"]["user_id"] == filter_dict["user_id"]
def test_different_users_different_filters(): def test_different_users_different_filters():
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {})) f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {})) f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
assert f_a["owner_id"] != f_b["owner_id"] assert f_a["user_id"] != f_b["user_id"]
def test_filter_overrides_conflicting_user_id(): def test_filter_overrides_conflicting_user_id():
"""If value already has a different user_id in metadata, it gets overwritten.""" """If value already has a different user_id in metadata, it gets overwritten."""
value = {"metadata": {"owner_id": "attacker"}} value = {"metadata": {"user_id": "attacker"}}
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value)) asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
assert value["metadata"]["owner_id"] == "real-owner" assert value["metadata"]["user_id"] == "real-owner"
def test_filter_with_empty_metadata(): def test_filter_with_empty_metadata():
"""Explicit empty metadata dict is fine.""" """Explicit empty metadata dict is fine."""
value = {"metadata": {}} value = {"metadata": {}}
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value)) result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
assert value["metadata"]["owner_id"] == "user-z" assert value["metadata"]["user_id"] == "user-z"
assert result == {"owner_id": "user-z"} assert result == {"user_id": "user-z"}
# ── Gateway parity ─────────────────────────────────────────────────────── # ── Gateway parity ───────────────────────────────────────────────────────
+7 -7
View File
@@ -9,8 +9,8 @@ These tests bypass the HTTP layer and exercise the storage-layer
owner filter directly by switching the ``user_context`` contextvar owner filter directly by switching the ``user_context`` contextvar
between two users. The safety property under test is: between two users. The safety property under test is:
After a repository write with owner_id=A, a subsequent read with After a repository write with user_id=A, a subsequent read with
owner_id=B must not return the row, and vice versa. user_id=B must not return the row, and vice versa.
The HTTP layer is covered by test_auth_middleware.py, which proves The HTTP layer is covered by test_auth_middleware.py, which proves
that a request cookie reaches the ``set_current_user`` call. Together that a request cookie reaches the ``set_current_user`` call. Together
@@ -431,13 +431,13 @@ async def test_repository_without_context_raises(tmp_path):
await cleanup() await cleanup()
# ── Escape hatch: explicit owner_id=None bypasses filter (for migration) ── # ── Escape hatch: explicit user_id=None bypasses filter (for migration) ──
@pytest.mark.anyio @pytest.mark.anyio
@pytest.mark.no_auto_user @pytest.mark.no_auto_user
async def test_explicit_none_bypasses_filter(tmp_path): async def test_explicit_none_bypasses_filter(tmp_path):
"""Migration scripts pass owner_id=None to see all rows regardless of owner.""" """Migration scripts pass user_id=None to see all rows regardless of owner."""
from deerflow.persistence.engine import get_session_factory from deerflow.persistence.engine import get_session_factory
from deerflow.persistence.thread_meta import ThreadMetaRepository from deerflow.persistence.thread_meta import ThreadMetaRepository
@@ -452,14 +452,14 @@ async def test_explicit_none_bypasses_filter(tmp_path):
await repo.create("t-beta") await repo.create("t-beta")
# Migration-style read: no contextvar, explicit None bypass. # Migration-style read: no contextvar, explicit None bypass.
all_rows = await repo.search(owner_id=None) all_rows = await repo.search(user_id=None)
thread_ids = {r["thread_id"] for r in all_rows} thread_ids = {r["thread_id"] for r in all_rows}
assert thread_ids == {"t-alpha", "t-beta"} assert thread_ids == {"t-alpha", "t-beta"}
# Explicit get with None does not apply the filter either. # Explicit get with None does not apply the filter either.
row_a = await repo.get("t-alpha", owner_id=None) row_a = await repo.get("t-alpha", user_id=None)
assert row_a is not None assert row_a is not None
row_b = await repo.get("t-beta", owner_id=None) row_b = await repo.get("t-beta", user_id=None)
assert row_b is not None assert row_b is not None
finally: finally:
await cleanup() await cleanup()
+8 -8
View File
@@ -2,7 +2,7 @@
Tests: Tests:
1. DatabaseConfig property derivation (paths, URLs) 1. DatabaseConfig property derivation (paths, URLs)
2. MemoryRunStore CRUD + owner_id filtering 2. MemoryRunStore CRUD + user_id filtering
3. Base.to_dict() via inspect mixin 3. Base.to_dict() via inspect mixin
4. Engine init/close lifecycle (memory + SQLite) 4. Engine init/close lifecycle (memory + SQLite)
5. Postgres missing-dep error message 5. Postgres missing-dep error message
@@ -106,17 +106,17 @@ class TestMemoryRunStore:
@pytest.mark.anyio @pytest.mark.anyio
async def test_list_by_thread_owner_filter(self, store): async def test_list_by_thread_owner_filter(self, store):
await store.put("r1", thread_id="t1", owner_id="alice") await store.put("r1", thread_id="t1", user_id="alice")
await store.put("r2", thread_id="t1", owner_id="bob") await store.put("r2", thread_id="t1", user_id="bob")
rows = await store.list_by_thread("t1", owner_id="alice") rows = await store.list_by_thread("t1", user_id="alice")
assert len(rows) == 1 assert len(rows) == 1
assert rows[0]["owner_id"] == "alice" assert rows[0]["user_id"] == "alice"
@pytest.mark.anyio @pytest.mark.anyio
async def test_owner_none_returns_all(self, store): async def test_owner_none_returns_all(self, store):
await store.put("r1", thread_id="t1", owner_id="alice") await store.put("r1", thread_id="t1", user_id="alice")
await store.put("r2", thread_id="t1", owner_id="bob") await store.put("r2", thread_id="t1", user_id="bob")
rows = await store.list_by_thread("t1", owner_id=None) rows = await store.list_by_thread("t1", user_id=None)
assert len(rows) == 2 assert len(rows) == 2
@pytest.mark.anyio @pytest.mark.anyio
+7 -7
View File
@@ -73,11 +73,11 @@ class TestRunRepository:
@pytest.mark.anyio @pytest.mark.anyio
async def test_list_by_thread_owner_filter(self, tmp_path): async def test_list_by_thread_owner_filter(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", owner_id="alice") await repo.put("r1", thread_id="t1", user_id="alice")
await repo.put("r2", thread_id="t1", owner_id="bob") await repo.put("r2", thread_id="t1", user_id="bob")
rows = await repo.list_by_thread("t1", owner_id="alice") rows = await repo.list_by_thread("t1", user_id="alice")
assert len(rows) == 1 assert len(rows) == 1
assert rows[0]["owner_id"] == "alice" assert rows[0]["user_id"] == "alice"
await _cleanup() await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
@@ -189,8 +189,8 @@ class TestRunRepository:
@pytest.mark.anyio @pytest.mark.anyio
async def test_owner_none_returns_all(self, tmp_path): async def test_owner_none_returns_all(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", owner_id="alice") await repo.put("r1", thread_id="t1", user_id="alice")
await repo.put("r2", thread_id="t1", owner_id="bob") await repo.put("r2", thread_id="t1", user_id="bob")
rows = await repo.list_by_thread("t1", owner_id=None) rows = await repo.list_by_thread("t1", user_id=None)
assert len(rows) == 2 assert len(rows) == 2
await _cleanup() await _cleanup()
+4 -4
View File
@@ -47,7 +47,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
# Bypass the require_permission decorator (which needs request + # Bypass the require_permission decorator (which needs request +
# thread_meta_repo) — these tests cover the parsing logic. # thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
assert result.suggestions == ["Q1", "Q2", "Q3"] assert result.suggestions == ["Q1", "Q2", "Q3"]
@@ -67,7 +67,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
# Bypass the require_permission decorator (which needs request + # Bypass the require_permission decorator (which needs request +
# thread_meta_repo) — these tests cover the parsing logic. # thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
assert result.suggestions == ["Q1", "Q2"] assert result.suggestions == ["Q1", "Q2"]
@@ -87,7 +87,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
# Bypass the require_permission decorator (which needs request + # Bypass the require_permission decorator (which needs request +
# thread_meta_repo) — these tests cover the parsing logic. # thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
assert result.suggestions == ["Q1", "Q2"] assert result.suggestions == ["Q1", "Q2"]
@@ -104,7 +104,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model) monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
# Bypass the require_permission decorator (which needs request + # Bypass the require_permission decorator (which needs request +
# thread_meta_repo) — these tests cover the parsing logic. # thread_store) — these tests cover the parsing logic.
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None)) result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
assert result.suggestions == [] assert result.suggestions == []
+10 -30
View File
@@ -43,8 +43,8 @@ class TestThreadMetaRepository:
@pytest.mark.anyio @pytest.mark.anyio
async def test_create_with_owner_and_display_name(self, tmp_path): async def test_create_with_owner_and_display_name(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
record = await repo.create("t1", owner_id="user1", display_name="My Thread") record = await repo.create("t1", user_id="user1", display_name="My Thread")
assert record["owner_id"] == "user1" assert record["user_id"] == "user1"
assert record["display_name"] == "My Thread" assert record["display_name"] == "My Thread"
await _cleanup() await _cleanup()
@@ -61,26 +61,6 @@ class TestThreadMetaRepository:
assert await repo.get("nonexistent") is None assert await repo.get("nonexistent") is None
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_list_by_owner(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.create("t1", owner_id="user1")
await repo.create("t2", owner_id="user1")
await repo.create("t3", owner_id="user2")
results = await repo.list_by_owner("user1")
assert len(results) == 2
assert all(r["owner_id"] == "user1" for r in results)
await _cleanup()
@pytest.mark.anyio
async def test_list_by_owner_with_limit_and_offset(self, tmp_path):
repo = await _make_repo(tmp_path)
for i in range(5):
await repo.create(f"t{i}", owner_id="user1")
results = await repo.list_by_owner("user1", limit=2, offset=1)
assert len(results) == 2
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_check_access_no_record_allows(self, tmp_path): async def test_check_access_no_record_allows(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
@@ -90,23 +70,23 @@ class TestThreadMetaRepository:
@pytest.mark.anyio @pytest.mark.anyio
async def test_check_access_owner_matches(self, tmp_path): async def test_check_access_owner_matches(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.create("t1", owner_id="user1") await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user1") is True assert await repo.check_access("t1", "user1") is True
await _cleanup() await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_check_access_owner_mismatch(self, tmp_path): async def test_check_access_owner_mismatch(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.create("t1", owner_id="user1") await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user2") is False assert await repo.check_access("t1", "user2") is False
await _cleanup() await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_check_access_no_owner_allows_all(self, tmp_path): async def test_check_access_no_owner_allows_all(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
# Explicit owner_id=None to bypass the new AUTO default that # Explicit user_id=None to bypass the new AUTO default that
# would otherwise pick up the test user from the autouse fixture. # would otherwise pick up the test user from the autouse fixture.
await repo.create("t1", owner_id=None) await repo.create("t1", user_id=None)
assert await repo.check_access("t1", "anyone") is True assert await repo.check_access("t1", "anyone") is True
await _cleanup() await _cleanup()
@@ -125,27 +105,27 @@ class TestThreadMetaRepository:
@pytest.mark.anyio @pytest.mark.anyio
async def test_check_access_strict_owner_match_allowed(self, tmp_path): async def test_check_access_strict_owner_match_allowed(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.create("t1", owner_id="user1") await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user1", require_existing=True) is True assert await repo.check_access("t1", "user1", require_existing=True) is True
await _cleanup() await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_check_access_strict_owner_mismatch_denied(self, tmp_path): async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.create("t1", owner_id="user1") await repo.create("t1", user_id="user1")
assert await repo.check_access("t1", "user2", require_existing=True) is False assert await repo.check_access("t1", "user2", require_existing=True) is False
await _cleanup() await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_check_access_strict_null_owner_still_allowed(self, tmp_path): async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
"""Even in strict mode, a row with NULL owner_id stays shared. """Even in strict mode, a row with NULL user_id stays shared.
The strict flag tightens the *missing row* case, not the *shared The strict flag tightens the *missing row* case, not the *shared
row* case legacy pre-auth rows that survived a clean migration row* case legacy pre-auth rows that survived a clean migration
without an owner are still everyone's. without an owner are still everyone's.
""" """
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
await repo.create("t1", owner_id=None) await repo.create("t1", user_id=None)
assert await repo.check_access("t1", "anyone", require_existing=True) is True assert await repo.check_access("t1", "anyone", require_existing=True) is True
await _cleanup() await _cleanup()
+3 -9
View File
@@ -113,14 +113,8 @@ def test_delete_thread_data_returns_generic_500_error(tmp_path):
# ── Server-reserved metadata key stripping ────────────────────────────────── # ── Server-reserved metadata key stripping ──────────────────────────────────
def test_strip_reserved_metadata_removes_owner_id():
"""Client-supplied owner_id is dropped to prevent reflection attacks."""
out = threads._strip_reserved_metadata({"owner_id": "victim-id", "title": "ok"})
assert out == {"title": "ok"}
def test_strip_reserved_metadata_removes_user_id(): def test_strip_reserved_metadata_removes_user_id():
"""user_id is also reserved (defense in depth for any future use).""" """Client-supplied user_id is dropped to prevent reflection attacks."""
out = threads._strip_reserved_metadata({"user_id": "victim-id", "title": "ok"}) out = threads._strip_reserved_metadata({"user_id": "victim-id", "title": "ok"})
assert out == {"title": "ok"} assert out == {"title": "ok"}
@@ -136,6 +130,6 @@ def test_strip_reserved_metadata_empty_input():
assert threads._strip_reserved_metadata({}) == {} assert threads._strip_reserved_metadata({}) == {}
def test_strip_reserved_metadata_strips_both_simultaneously(): def test_strip_reserved_metadata_strips_all_reserved_keys():
out = threads._strip_reserved_metadata({"owner_id": "x", "user_id": "y", "keep": "me"}) out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"})
assert out == {"keep": "me"} assert out == {"keep": "me"}