feat(auth): enforce owner_id across 2.0-rc persistence layer

Add request-scoped contextvar-based owner filtering to threads_meta,
runs, run_events, and feedback repositories. Router code is unchanged
— isolation is enforced at the storage layer so that any caller that
forgets to pass owner_id still gets filtered results, and new routes
cannot accidentally leak data.

Core infrastructure
-------------------
- deerflow/runtime/user_context.py (new):
  - ContextVar[CurrentUser | None] with default None
  - runtime_checkable CurrentUser Protocol (structural subtype with .id)
  - set/reset/get/require helpers
  - AUTO sentinel + resolve_owner_id(value, method_name) for sentinel
    three-state resolution: AUTO reads contextvar, explicit str
    overrides, explicit None bypasses the filter (for migration/CLI)

Repository changes
------------------
- ThreadMetaRepository: create/get/search/update_*/delete gain
  owner_id=AUTO kwarg; read paths filter by owner, writes stamp it,
  mutations check ownership before applying
- RunRepository: put/get/list_by_thread/delete gain owner_id=AUTO kwarg
- FeedbackRepository: create/get/list_by_run/list_by_thread/delete
  gain owner_id=AUTO kwarg
- DbRunEventStore: list_messages/list_events/list_messages_by_run/
  count_messages/delete_by_thread/delete_by_run gain owner_id=AUTO
  kwarg. Write paths (put/put_batch) read contextvar softly: when a
  request-scoped user is available, owner_id is stamped; background
  worker writes without a user context pass None which is valid
  (orphan row to be bound by migration)

Schema
------
- persistence/models/run_event.py: RunEventRow.owner_id = Mapped[
  str | None] = mapped_column(String(64), nullable=True, index=True)
- No alembic migration needed: 2.0 ships fresh, Base.metadata.create_all
  picks up the new column automatically

Middleware
----------
- auth_middleware.py: after cookie check, call get_optional_user_from_
  request to load the real User, stamp it into request.state.user AND
  the contextvar via set_current_user, reset in a try/finally. Public
  paths and unauthenticated requests continue without contextvar, and
  @require_auth handles the strict 401 path

Test infrastructure
-------------------
- tests/conftest.py: @pytest.fixture(autouse=True) _auto_user_context
  sets a default SimpleNamespace(id="test-user-autouse") on every test
  unless marked @pytest.mark.no_auto_user. Keeps existing 20+
  persistence tests passing without modification
- pyproject.toml [tool.pytest.ini_options]: register no_auto_user
  marker so pytest does not emit warnings for opt-out tests
- tests/test_user_context.py: 6 tests covering three-state semantics,
  Protocol duck typing, and require/optional APIs
- tests/test_thread_meta_repo.py: one test updated to pass owner_id=
  None explicitly where it was previously relying on the old default

Test results
------------
- test_user_context.py: 6 passed
- test_auth*.py + test_langgraph_auth.py + test_ensure_admin.py: 127
- test_run_event_store / test_run_repository / test_thread_meta_repo
  / test_feedback: 92 passed
- Full backend suite: 1905 passed, 2 failed (both @requires_llm flaky
  integration tests unrelated to auth), 1 skipped
This commit is contained in:
greatmengqi
2026-04-08 11:08:48 +08:00
parent 2531cce0d1
commit 4b139fb689
11 changed files with 566 additions and 58 deletions
@@ -12,6 +12,7 @@ from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
class FeedbackRepository:
@@ -32,18 +33,19 @@ class FeedbackRepository:
run_id: str,
thread_id: str,
rating: int,
owner_id: str | None = None,
owner_id: "str | None | _AutoSentinel" = AUTO,
message_id: str | None = None,
comment: str | None = None,
) -> dict:
"""Create a feedback record. rating must be +1 or -1."""
if rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {rating}")
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
row = FeedbackRow(
feedback_id=str(uuid.uuid4()),
run_id=run_id,
thread_id=thread_id,
owner_id=owner_id,
owner_id=resolved_owner_id,
message_id=message_id,
rating=rating,
comment=comment,
@@ -55,27 +57,66 @@ class FeedbackRepository:
await session.refresh(row)
return self._row_to_dict(row)
async def get(self, feedback_id: str) -> dict | None:
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
return self._row_to_dict(row) if row else None
async def list_by_run(self, thread_id: str, run_id: str, *, limit: int = 100) -> list[dict]:
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict]:
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def delete(self, feedback_id: str) -> bool:
async def get(
self,
feedback_id: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> dict | None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
if row is None:
return None
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return None
return self._row_to_dict(row)
async def list_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 100,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> list[dict]:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
if resolved_owner_id is not None:
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_by_thread(
self,
thread_id: str,
*,
limit: int = 100,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> list[dict]:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
if resolved_owner_id is not None:
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def delete(
self,
feedback_id: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> bool:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
if row is None:
return False
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return False
await session.delete(row)
await session.commit()
@@ -16,6 +16,10 @@ class RunEventRow(Base):
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
# Owner of the conversation this event belongs to. Nullable for data
# created before auth was introduced; populated by auth middleware on
# new writes and by the boot-time orphan migration on existing rows.
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
category: Mapped[str] = mapped_column(String(16), nullable=False)
# "message" | "trace" | "lifecycle"
@@ -16,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.run.model import RunRow
from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
class RunRepository(RunStore):
@@ -68,7 +69,7 @@ class RunRepository(RunStore):
*,
thread_id,
assistant_id=None,
owner_id=None,
owner_id: "str | None | _AutoSentinel" = AUTO,
status="pending",
multitask_strategy="reject",
metadata=None,
@@ -77,12 +78,13 @@ class RunRepository(RunStore):
created_at=None,
follow_up_to_run_id=None,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
now = datetime.now(UTC)
row = RunRow(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
owner_id=owner_id,
owner_id=resolved_owner_id,
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
@@ -96,15 +98,32 @@ class RunRepository(RunStore):
session.add(row)
await session.commit()
async def get(self, run_id):
async def get(
self,
run_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
async with self._sf() as session:
row = await session.get(RunRow, run_id)
return self._row_to_dict(row) if row else None
if row is None:
return None
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return None
return self._row_to_dict(row)
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
async def list_by_thread(
self,
thread_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
limit=100,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
if owner_id is not None:
stmt = stmt.where(RunRow.owner_id == owner_id)
if resolved_owner_id is not None:
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
@@ -118,12 +137,21 @@ class RunRepository(RunStore):
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
async def delete(self, run_id):
async def delete(
self,
run_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
async with self._sf() as session:
row = await session.get(RunRow, run_id)
if row is not None:
await session.delete(row)
await session.commit()
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return
await session.delete(row)
await session.commit()
async def list_pending(self, *, before=None):
if before is None:
@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
class ThreadMetaRepository(ThreadMetaStore):
@@ -31,15 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
thread_id: str,
*,
assistant_id: str | None = None,
owner_id: str | None = None,
owner_id: "str | None | _AutoSentinel" = AUTO,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
# Auto-resolve owner_id from contextvar when AUTO; explicit None
# creates an orphan row (used by migration scripts).
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
now = datetime.now(UTC)
row = ThreadMetaRow(
thread_id=thread_id,
assistant_id=assistant_id,
owner_id=owner_id,
owner_id=resolved_owner_id,
display_name=display_name,
metadata_json=metadata or {},
created_at=now,
@@ -51,10 +55,21 @@ class ThreadMetaRepository(ThreadMetaStore):
await session.refresh(row)
return self._row_to_dict(row)
async def get(self, thread_id: str) -> dict | None:
async def get(
self,
thread_id: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> dict | None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
return self._row_to_dict(row) if row else None
if row is None:
return None
# Enforce owner filter unless explicitly bypassed (owner_id=None).
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return None
return self._row_to_dict(row)
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
@@ -83,9 +98,17 @@ class ThreadMetaRepository(ThreadMetaStore):
status: str | None = None,
limit: int = 100,
offset: int = 0,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> list[dict]:
"""Search threads with optional metadata and status filters."""
"""Search threads with optional metadata and status filters.
Owner filter is enforced by default: caller must be in a user
context. Pass ``owner_id=None`` to bypass (migration/CLI).
"""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
if resolved_owner_id is not None:
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
if status:
stmt = stmt.where(ThreadMetaRow.status == status)
@@ -105,36 +128,80 @@ class ThreadMetaRepository(ThreadMetaStore):
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_display_name(self, thread_id: str, display_name: str) -> None:
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
"""Return True if the row exists and is owned (or filter bypassed)."""
if resolved_owner_id is None:
return True # explicit bypass
row = await session.get(ThreadMetaRow, thread_id)
return row is not None and row.owner_id == resolved_owner_id
async def update_display_name(
self,
thread_id: str,
display_name: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> None:
"""Update the display_name (title) for a thread."""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_owner_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
await session.commit()
async def update_status(self, thread_id: str, status: str) -> None:
async def update_status(
self,
thread_id: str,
status: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_owner_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
await session.commit()
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
async def update_metadata(
self,
thread_id: str,
metadata: dict,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> None:
"""Merge ``metadata`` into ``metadata_json``.
Read-modify-write inside a single session/transaction so concurrent
callers see consistent state. No-op if the row does not exist.
callers see consistent state. No-op if the row does not exist or
the owner_id check fails.
"""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return
merged = dict(row.metadata_json or {})
merged.update(metadata)
row.metadata_json = merged
row.updated_at = datetime.now(UTC)
await session.commit()
async def delete(self, thread_id: str) -> None:
async def delete(
self,
thread_id: str,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
) -> None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is not None:
await session.delete(row)
await session.commit()
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
return
await session.delete(row)
await session.commit()
@@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
logger = logging.getLogger(__name__)
@@ -53,6 +54,18 @@ class DbRunEventStore(RunEventStore):
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
return content, metadata or {}
@staticmethod
def _owner_from_context() -> str | None:
"""Soft read of owner_id from contextvar for write paths.
Returns ``None`` (no filter / no stamp) if contextvar is unset,
which is the expected case for background worker writes. HTTP
request writes will have the contextvar set by auth middleware
and get their user_id stamped automatically.
"""
user = get_current_user()
return user.id if user is not None else None
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
"""Write a single event — low-frequency path only.
@@ -68,6 +81,7 @@ class DbRunEventStore(RunEventStore):
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
owner_id = self._owner_from_context()
async with self._sf() as session:
async with session.begin():
# Use FOR UPDATE to serialize seq assignment within a thread.
@@ -78,6 +92,7 @@ class DbRunEventStore(RunEventStore):
row = RunEventRow(
thread_id=thread_id,
run_id=run_id,
owner_id=owner_id,
event_type=event_type,
category=category,
content=db_content,
@@ -91,6 +106,7 @@ class DbRunEventStore(RunEventStore):
async def put_batch(self, events):
if not events:
return []
owner_id = self._owner_from_context()
async with self._sf() as session:
async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread).
@@ -114,6 +130,7 @@ class DbRunEventStore(RunEventStore):
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
owner_id=e.get("owner_id", owner_id),
event_type=e["event_type"],
category=category,
content=db_content,
@@ -125,8 +142,19 @@ class DbRunEventStore(RunEventStore):
rows.append(row)
return [self._row_to_dict(r) for r in rows]
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
async def list_messages(
self,
thread_id,
*,
limit=50,
before_seq=None,
after_seq=None,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
if before_seq is not None:
stmt = stmt.where(RunEventRow.seq < before_seq)
if after_seq is not None:
@@ -146,8 +174,19 @@ class DbRunEventStore(RunEventStore):
rows = list(result.scalars())
return [self._row_to_dict(r) for r in reversed(rows)]
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
async def list_events(
self,
thread_id,
run_id,
*,
event_types=None,
limit=500,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
if event_types:
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
@@ -155,31 +194,68 @@ class DbRunEventStore(RunEventStore):
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_messages_by_run(self, thread_id, run_id):
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message").order_by(RunEventRow.seq.asc())
async def list_messages_by_run(
self,
thread_id,
run_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
stmt = stmt.order_by(RunEventRow.seq.asc())
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def count_messages(self, thread_id):
async def count_messages(
self,
thread_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
async with self._sf() as session:
return await session.scalar(stmt) or 0
async def delete_by_thread(self, thread_id):
async def delete_by_thread(
self,
thread_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
async with self._sf() as session:
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id)
count_conditions = [RunEventRow.thread_id == thread_id]
if resolved_owner_id is not None:
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0
if count > 0:
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id))
await session.execute(delete(RunEventRow).where(*count_conditions))
await session.commit()
return count
async def delete_by_run(self, thread_id, run_id):
async def delete_by_run(
self,
thread_id,
run_id,
*,
owner_id: "str | None | _AutoSentinel" = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
async with self._sf() as session:
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
if resolved_owner_id is not None:
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0
if count > 0:
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id))
await session.execute(delete(RunEventRow).where(*count_conditions))
await session.commit()
return count
@@ -0,0 +1,148 @@
"""Request-scoped user context for owner-based authorization.
This module holds a :class:`~contextvars.ContextVar` that the gateway's
auth middleware sets after a successful authentication. Repository
methods read the contextvar via a sentinel default parameter, letting
routers stay free of ``owner_id`` boilerplate.
Three-state semantics for the repository ``owner_id`` parameter (the
consumer side of this module lives in ``deerflow.persistence.*``):
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
raise :class:`RuntimeError` if unset.
- Explicit ``str``: use the provided value, overriding contextvar.
- Explicit ``None``: no WHERE clause — used only by migration scripts
and admin CLIs that intentionally bypass isolation.
Dependency direction
--------------------
``persistence`` (lower layer) reads from this module; ``gateway.auth``
(higher layer) writes to it. ``CurrentUser`` is defined here as a
:class:`typing.Protocol` so that ``persistence`` never needs to import
the concrete ``User`` class from ``gateway.auth.models``. Any object
with an ``.id: str`` attribute structurally satisfies the protocol.
Asyncio semantics
-----------------
``ContextVar`` is task-local under asyncio, not thread-local. Each
FastAPI request runs in its own task, so the context is naturally
isolated. ``asyncio.create_task`` and ``asyncio.to_thread`` inherit the
parent task's context, which is typically the intended behaviour; if
a background task must *not* see the foreground user, wrap it with
``contextvars.copy_context()`` to get a clean copy.
"""
from __future__ import annotations
from contextvars import ContextVar, Token
from typing import Final, Protocol, runtime_checkable
@runtime_checkable
class CurrentUser(Protocol):
"""Structural type for the current authenticated user.
Any object with an ``.id: str`` attribute satisfies this protocol.
Concrete implementations live in ``app.gateway.auth.models.User``.
"""
id: str
_current_user: Final[ContextVar["CurrentUser | None"]] = ContextVar(
"deerflow_current_user", default=None
)
def set_current_user(user: CurrentUser) -> Token[CurrentUser | None]:
"""Set the current user for this async task.
Returns a reset token that should be passed to
:func:`reset_current_user` in a ``finally`` block to restore the
previous context.
"""
return _current_user.set(user)
def reset_current_user(token: Token[CurrentUser | None]) -> None:
"""Restore the context to the state captured by ``token``."""
_current_user.reset(token)
def get_current_user() -> CurrentUser | None:
"""Return the current user, or ``None`` if unset.
Safe to call in any context. Used by code paths that can proceed
without a user (e.g. migration scripts, public endpoints).
"""
return _current_user.get()
def require_current_user() -> CurrentUser:
"""Return the current user, or raise :class:`RuntimeError`.
Used by repository code that must not be called outside a
request-authenticated context. The error message is phrased so
that a caller debugging a stack trace can locate the offending
code path.
"""
user = _current_user.get()
if user is None:
raise RuntimeError("repository accessed without user context")
return user
# ---------------------------------------------------------------------------
# Sentinel-based owner_id resolution
# ---------------------------------------------------------------------------
#
# Repository methods accept an ``owner_id`` keyword-only argument that
# defaults to ``AUTO``. The three possible values drive distinct
# behaviours; see the docstring on :func:`resolve_owner_id`.
class _AutoSentinel:
"""Singleton marker meaning 'resolve owner_id from contextvar'."""
_instance: "_AutoSentinel | None" = None
def __new__(cls) -> "_AutoSentinel":
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __repr__(self) -> str:
return "<AUTO>"
AUTO: Final[_AutoSentinel] = _AutoSentinel()
def resolve_owner_id(
value: "str | None | _AutoSentinel",
*,
method_name: str = "repository method",
) -> str | None:
"""Resolve the owner_id parameter passed to a repository method.
Three-state semantics:
- :data:`AUTO` (default): read from contextvar; raise
:class:`RuntimeError` if no user is in context. This is the
common case for request-scoped calls.
- Explicit ``str``: use the provided id verbatim, overriding any
contextvar value. Useful for tests and admin-override flows.
- Explicit ``None``: no filter — the repository should skip the
owner_id WHERE clause entirely. Reserved for migration scripts
and CLI tools that intentionally bypass isolation.
"""
if isinstance(value, _AutoSentinel):
user = _current_user.get()
if user is None:
raise RuntimeError(
f"{method_name} called with owner_id=AUTO but no user context is set; "
"pass an explicit owner_id, set the contextvar via auth middleware, "
"or opt out with owner_id=None for migration/CLI paths."
)
return user.id
return value