mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
feat(auth): enforce owner_id across 2.0-rc persistence layer
Add request-scoped contextvar-based owner filtering to threads_meta,
runs, run_events, and feedback repositories. Router code is unchanged
— isolation is enforced at the storage layer so that any caller that
forgets to pass owner_id still gets filtered results, and new routes
cannot accidentally leak data.
Core infrastructure
-------------------
- deerflow/runtime/user_context.py (new):
- ContextVar[CurrentUser | None] with default None
- runtime_checkable CurrentUser Protocol (structural subtype with .id)
- set/reset/get/require helpers
- AUTO sentinel + resolve_owner_id(value, method_name) for sentinel
three-state resolution: AUTO reads contextvar, explicit str
overrides, explicit None bypasses the filter (for migration/CLI)
Repository changes
------------------
- ThreadMetaRepository: create/get/search/update_*/delete gain
owner_id=AUTO kwarg; read paths filter by owner, writes stamp it,
mutations check ownership before applying
- RunRepository: put/get/list_by_thread/delete gain owner_id=AUTO kwarg
- FeedbackRepository: create/get/list_by_run/list_by_thread/delete
gain owner_id=AUTO kwarg
- DbRunEventStore: list_messages/list_events/list_messages_by_run/
count_messages/delete_by_thread/delete_by_run gain owner_id=AUTO
kwarg. Write paths (put/put_batch) read contextvar softly: when a
request-scoped user is available, owner_id is stamped; background
worker writes without a user context pass None which is valid
(orphan row to be bound by migration)
Schema
------
- persistence/models/run_event.py: RunEventRow.owner_id = Mapped[
str | None] = mapped_column(String(64), nullable=True, index=True)
- No alembic migration needed: 2.0 ships fresh, Base.metadata.create_all
picks up the new column automatically
Middleware
----------
- auth_middleware.py: after cookie check, call get_optional_user_from_
request to load the real User, stamp it into request.state.user AND
the contextvar via set_current_user, reset in a try/finally. Public
paths and unauthenticated requests continue without contextvar, and
@require_auth handles the strict 401 path
Test infrastructure
-------------------
- tests/conftest.py: @pytest.fixture(autouse=True) _auto_user_context
sets a default SimpleNamespace(id="test-user-autouse") on every test
unless marked @pytest.mark.no_auto_user. Keeps existing 20+
persistence tests passing without modification
- pyproject.toml [tool.pytest.ini_options]: register no_auto_user
marker so pytest does not emit warnings for opt-out tests
- tests/test_user_context.py: 6 tests covering three-state semantics,
Protocol duck typing, and require/optional APIs
- tests/test_thread_meta_repo.py: one test updated to pass owner_id=
None explicitly where it was previously relying on the old default
Test results
------------
- test_user_context.py: 6 passed
- test_auth*.py + test_langgraph_auth.py + test_ensure_admin.py: 127
- test_run_event_store / test_run_repository / test_thread_meta_repo
/ test_feedback: 92 passed
- Full backend suite: 1905 passed, 2 failed (both @requires_llm flaky
integration tests unrelated to auth), 1 skipped
This commit is contained in:
@@ -1,6 +1,11 @@
|
|||||||
"""Global authentication middleware — fail-closed safety net.
|
"""Global authentication middleware — fail-closed safety net.
|
||||||
|
|
||||||
Rejects unauthenticated requests to non-public paths with 401.
|
Rejects unauthenticated requests to non-public paths with 401. When a
|
||||||
|
request passes the cookie check, resolves the JWT payload to a real
|
||||||
|
``User`` object and stamps it into both ``request.state.user`` and the
|
||||||
|
``deerflow.runtime.user_context`` contextvar so that repository-layer
|
||||||
|
owner filtering works automatically via the sentinel pattern.
|
||||||
|
|
||||||
Fine-grained permission checks remain in authz.py decorators.
|
Fine-grained permission checks remain in authz.py decorators.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -12,6 +17,7 @@ from starlette.responses import JSONResponse
|
|||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
from app.gateway.auth.errors import AuthErrorCode
|
from app.gateway.auth.errors import AuthErrorCode
|
||||||
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||||
|
|
||||||
# Paths that never require authentication.
|
# Paths that never require authentication.
|
||||||
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
|
_PUBLIC_PATH_PREFIXES: tuple[str, ...] = (
|
||||||
@@ -68,4 +74,22 @@ class AuthMiddleware(BaseHTTPMiddleware):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
return await call_next(request)
|
# Resolve the full user now so repository-layer owner filters
|
||||||
|
# can read from the contextvar. We use the "optional" flavour so
|
||||||
|
# middleware never raises on bad tokens — the cookie-presence
|
||||||
|
# check above plus the @require_auth decorator provide the
|
||||||
|
# strict gates. A stale/invalid token yields user=None here;
|
||||||
|
# the request continues without a contextvar, and any protected
|
||||||
|
# endpoint will still be rejected by @require_auth.
|
||||||
|
from app.gateway.deps import get_optional_user_from_request
|
||||||
|
|
||||||
|
user = await get_optional_user_from_request(request)
|
||||||
|
if user is None:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
request.state.user = user
|
||||||
|
token = set_current_user(user)
|
||||||
|
try:
|
||||||
|
return await call_next(request)
|
||||||
|
finally:
|
||||||
|
reset_current_user(token)
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from sqlalchemy import case, func, select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
from deerflow.persistence.feedback.model import FeedbackRow
|
from deerflow.persistence.feedback.model import FeedbackRow
|
||||||
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||||
|
|
||||||
|
|
||||||
class FeedbackRepository:
|
class FeedbackRepository:
|
||||||
@@ -32,18 +33,19 @@ class FeedbackRepository:
|
|||||||
run_id: str,
|
run_id: str,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
rating: int,
|
rating: int,
|
||||||
owner_id: str | None = None,
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
comment: str | None = None,
|
comment: str | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Create a feedback record. rating must be +1 or -1."""
|
"""Create a feedback record. rating must be +1 or -1."""
|
||||||
if rating not in (1, -1):
|
if rating not in (1, -1):
|
||||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
|
||||||
row = FeedbackRow(
|
row = FeedbackRow(
|
||||||
feedback_id=str(uuid.uuid4()),
|
feedback_id=str(uuid.uuid4()),
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
owner_id=owner_id,
|
owner_id=resolved_owner_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
rating=rating,
|
rating=rating,
|
||||||
comment=comment,
|
comment=comment,
|
||||||
@@ -55,27 +57,66 @@ class FeedbackRepository:
|
|||||||
await session.refresh(row)
|
await session.refresh(row)
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
async def get(self, feedback_id: str) -> dict | None:
|
async def get(
|
||||||
async with self._sf() as session:
|
self,
|
||||||
row = await session.get(FeedbackRow, feedback_id)
|
feedback_id: str,
|
||||||
return self._row_to_dict(row) if row else None
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
async def list_by_run(self, thread_id: str, run_id: str, *, limit: int = 100) -> list[dict]:
|
) -> dict | None:
|
||||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
|
||||||
|
|
||||||
async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict]:
|
|
||||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
|
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
|
||||||
|
|
||||||
async def delete(self, feedback_id: str) -> bool:
|
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(FeedbackRow, feedback_id)
|
row = await session.get(FeedbackRow, feedback_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
|
return None
|
||||||
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
|
return None
|
||||||
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
|
async def list_by_run(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 100,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
) -> list[dict]:
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
|
||||||
|
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||||
|
if resolved_owner_id is not None:
|
||||||
|
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||||
|
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||||
|
async with self._sf() as session:
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|
||||||
|
async def list_by_thread(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
limit: int = 100,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
) -> list[dict]:
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
|
||||||
|
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||||
|
if resolved_owner_id is not None:
|
||||||
|
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
||||||
|
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||||
|
async with self._sf() as session:
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|
||||||
|
async def delete(
|
||||||
|
self,
|
||||||
|
feedback_id: str,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
) -> bool:
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
|
||||||
|
async with self._sf() as session:
|
||||||
|
row = await session.get(FeedbackRow, feedback_id)
|
||||||
|
if row is None:
|
||||||
|
return False
|
||||||
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
return False
|
return False
|
||||||
await session.delete(row)
|
await session.delete(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ class RunEventRow(Base):
|
|||||||
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
|
||||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
thread_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
run_id: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
# Owner of the conversation this event belongs to. Nullable for data
|
||||||
|
# created before auth was introduced; populated by auth middleware on
|
||||||
|
# new writes and by the boot-time orphan migration on existing rows.
|
||||||
|
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||||
# "message" | "trace" | "lifecycle"
|
# "message" | "trace" | "lifecycle"
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.run.model import RunRow
|
from deerflow.persistence.run.model import RunRow
|
||||||
from deerflow.runtime.runs.store.base import RunStore
|
from deerflow.runtime.runs.store.base import RunStore
|
||||||
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||||
|
|
||||||
|
|
||||||
class RunRepository(RunStore):
|
class RunRepository(RunStore):
|
||||||
@@ -68,7 +69,7 @@ class RunRepository(RunStore):
|
|||||||
*,
|
*,
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=None,
|
assistant_id=None,
|
||||||
owner_id=None,
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
status="pending",
|
status="pending",
|
||||||
multitask_strategy="reject",
|
multitask_strategy="reject",
|
||||||
metadata=None,
|
metadata=None,
|
||||||
@@ -77,12 +78,13 @@ class RunRepository(RunStore):
|
|||||||
created_at=None,
|
created_at=None,
|
||||||
follow_up_to_run_id=None,
|
follow_up_to_run_id=None,
|
||||||
):
|
):
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
row = RunRow(
|
row = RunRow(
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
assistant_id=assistant_id,
|
assistant_id=assistant_id,
|
||||||
owner_id=owner_id,
|
owner_id=resolved_owner_id,
|
||||||
status=status,
|
status=status,
|
||||||
multitask_strategy=multitask_strategy,
|
multitask_strategy=multitask_strategy,
|
||||||
metadata_json=self._safe_json(metadata) or {},
|
metadata_json=self._safe_json(metadata) or {},
|
||||||
@@ -96,15 +98,32 @@ class RunRepository(RunStore):
|
|||||||
session.add(row)
|
session.add(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def get(self, run_id):
|
async def get(
|
||||||
|
self,
|
||||||
|
run_id,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
):
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(RunRow, run_id)
|
row = await session.get(RunRow, run_id)
|
||||||
return self._row_to_dict(row) if row else None
|
if row is None:
|
||||||
|
return None
|
||||||
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
|
return None
|
||||||
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
|
async def list_by_thread(
|
||||||
|
self,
|
||||||
|
thread_id,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
limit=100,
|
||||||
|
):
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
|
||||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||||
if owner_id is not None:
|
if resolved_owner_id is not None:
|
||||||
stmt = stmt.where(RunRow.owner_id == owner_id)
|
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
|
||||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -118,12 +137,21 @@ class RunRepository(RunStore):
|
|||||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def delete(self, run_id):
|
async def delete(
|
||||||
|
self,
|
||||||
|
run_id,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
):
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(RunRow, run_id)
|
row = await session.get(RunRow, run_id)
|
||||||
if row is not None:
|
if row is None:
|
||||||
await session.delete(row)
|
return
|
||||||
await session.commit()
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
|
return
|
||||||
|
await session.delete(row)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def list_pending(self, *, before=None):
|
async def list_pending(self, *, before=None):
|
||||||
if before is None:
|
if before is None:
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||||
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
||||||
|
|
||||||
|
|
||||||
class ThreadMetaRepository(ThreadMetaStore):
|
class ThreadMetaRepository(ThreadMetaStore):
|
||||||
@@ -31,15 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
owner_id: str | None = None,
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
display_name: str | None = None,
|
display_name: str | None = None,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
|
# Auto-resolve owner_id from contextvar when AUTO; explicit None
|
||||||
|
# creates an orphan row (used by migration scripts).
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
row = ThreadMetaRow(
|
row = ThreadMetaRow(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
assistant_id=assistant_id,
|
assistant_id=assistant_id,
|
||||||
owner_id=owner_id,
|
owner_id=resolved_owner_id,
|
||||||
display_name=display_name,
|
display_name=display_name,
|
||||||
metadata_json=metadata or {},
|
metadata_json=metadata or {},
|
||||||
created_at=now,
|
created_at=now,
|
||||||
@@ -51,10 +55,21 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
await session.refresh(row)
|
await session.refresh(row)
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
async def get(self, thread_id: str) -> dict | None:
|
async def get(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
) -> dict | None:
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
return self._row_to_dict(row) if row else None
|
if row is None:
|
||||||
|
return None
|
||||||
|
# Enforce owner filter unless explicitly bypassed (owner_id=None).
|
||||||
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
|
return None
|
||||||
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
||||||
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
||||||
@@ -83,9 +98,17 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Search threads with optional metadata and status filters."""
|
"""Search threads with optional metadata and status filters.
|
||||||
|
|
||||||
|
Owner filter is enforced by default: caller must be in a user
|
||||||
|
context. Pass ``owner_id=None`` to bypass (migration/CLI).
|
||||||
|
"""
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
|
||||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||||
|
if resolved_owner_id is not None:
|
||||||
|
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
|
||||||
if status:
|
if status:
|
||||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||||
|
|
||||||
@@ -105,36 +128,80 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|
||||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
|
||||||
|
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||||
|
if resolved_owner_id is None:
|
||||||
|
return True # explicit bypass
|
||||||
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
|
return row is not None and row.owner_id == resolved_owner_id
|
||||||
|
|
||||||
|
async def update_display_name(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
display_name: str,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
) -> None:
|
||||||
"""Update the display_name (title) for a thread."""
|
"""Update the display_name (title) for a thread."""
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
|
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||||
|
return
|
||||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def update_status(self, thread_id: str, status: str) -> None:
|
async def update_status(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
status: str,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
) -> None:
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
|
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
||||||
|
return
|
||||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def update_metadata(self, thread_id: str, metadata: dict) -> None:
|
async def update_metadata(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
metadata: dict,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
) -> None:
|
||||||
"""Merge ``metadata`` into ``metadata_json``.
|
"""Merge ``metadata`` into ``metadata_json``.
|
||||||
|
|
||||||
Read-modify-write inside a single session/transaction so concurrent
|
Read-modify-write inside a single session/transaction so concurrent
|
||||||
callers see consistent state. No-op if the row does not exist.
|
callers see consistent state. No-op if the row does not exist or
|
||||||
|
the owner_id check fails.
|
||||||
"""
|
"""
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
|
return
|
||||||
merged = dict(row.metadata_json or {})
|
merged = dict(row.metadata_json or {})
|
||||||
merged.update(metadata)
|
merged.update(metadata)
|
||||||
row.metadata_json = merged
|
row.metadata_json = merged
|
||||||
row.updated_at = datetime.now(UTC)
|
row.updated_at = datetime.now(UTC)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def delete(self, thread_id: str) -> None:
|
async def delete(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
) -> None:
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is not None:
|
if row is None:
|
||||||
await session.delete(row)
|
return
|
||||||
await session.commit()
|
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
||||||
|
return
|
||||||
|
await session.delete(row)
|
||||||
|
await session.commit()
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.models.run_event import RunEventRow
|
from deerflow.persistence.models.run_event import RunEventRow
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
from deerflow.runtime.events.store.base import RunEventStore
|
||||||
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -53,6 +54,18 @@ class DbRunEventStore(RunEventStore):
|
|||||||
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
|
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
|
||||||
return content, metadata or {}
|
return content, metadata or {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _owner_from_context() -> str | None:
|
||||||
|
"""Soft read of owner_id from contextvar for write paths.
|
||||||
|
|
||||||
|
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||||
|
which is the expected case for background worker writes. HTTP
|
||||||
|
request writes will have the contextvar set by auth middleware
|
||||||
|
and get their user_id stamped automatically.
|
||||||
|
"""
|
||||||
|
user = get_current_user()
|
||||||
|
return user.id if user is not None else None
|
||||||
|
|
||||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||||
"""Write a single event — low-frequency path only.
|
"""Write a single event — low-frequency path only.
|
||||||
|
|
||||||
@@ -68,6 +81,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||||
else:
|
else:
|
||||||
db_content = content
|
db_content = content
|
||||||
|
owner_id = self._owner_from_context()
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||||
@@ -78,6 +92,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
|
owner_id=owner_id,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
category=category,
|
category=category,
|
||||||
content=db_content,
|
content=db_content,
|
||||||
@@ -91,6 +106,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
async def put_batch(self, events):
|
async def put_batch(self, events):
|
||||||
if not events:
|
if not events:
|
||||||
return []
|
return []
|
||||||
|
owner_id = self._owner_from_context()
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||||
@@ -114,6 +130,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=e["thread_id"],
|
thread_id=e["thread_id"],
|
||||||
run_id=e["run_id"],
|
run_id=e["run_id"],
|
||||||
|
owner_id=e.get("owner_id", owner_id),
|
||||||
event_type=e["event_type"],
|
event_type=e["event_type"],
|
||||||
category=category,
|
category=category,
|
||||||
content=db_content,
|
content=db_content,
|
||||||
@@ -125,8 +142,19 @@ class DbRunEventStore(RunEventStore):
|
|||||||
rows.append(row)
|
rows.append(row)
|
||||||
return [self._row_to_dict(r) for r in rows]
|
return [self._row_to_dict(r) for r in rows]
|
||||||
|
|
||||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
async def list_messages(
|
||||||
|
self,
|
||||||
|
thread_id,
|
||||||
|
*,
|
||||||
|
limit=50,
|
||||||
|
before_seq=None,
|
||||||
|
after_seq=None,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
):
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||||
|
if resolved_owner_id is not None:
|
||||||
|
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||||
if before_seq is not None:
|
if before_seq is not None:
|
||||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||||
if after_seq is not None:
|
if after_seq is not None:
|
||||||
@@ -146,8 +174,19 @@ class DbRunEventStore(RunEventStore):
|
|||||||
rows = list(result.scalars())
|
rows = list(result.scalars())
|
||||||
return [self._row_to_dict(r) for r in reversed(rows)]
|
return [self._row_to_dict(r) for r in reversed(rows)]
|
||||||
|
|
||||||
async def list_events(self, thread_id, run_id, *, event_types=None, limit=500):
|
async def list_events(
|
||||||
|
self,
|
||||||
|
thread_id,
|
||||||
|
run_id,
|
||||||
|
*,
|
||||||
|
event_types=None,
|
||||||
|
limit=500,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
):
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||||
|
if resolved_owner_id is not None:
|
||||||
|
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||||
if event_types:
|
if event_types:
|
||||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||||
@@ -155,31 +194,68 @@ class DbRunEventStore(RunEventStore):
|
|||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|
||||||
async def list_messages_by_run(self, thread_id, run_id):
|
async def list_messages_by_run(
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message").order_by(RunEventRow.seq.asc())
|
self,
|
||||||
|
thread_id,
|
||||||
|
run_id,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
):
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||||
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
|
||||||
|
if resolved_owner_id is not None:
|
||||||
|
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||||
|
stmt = stmt.order_by(RunEventRow.seq.asc())
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|
||||||
async def count_messages(self, thread_id):
|
async def count_messages(
|
||||||
|
self,
|
||||||
|
thread_id,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
):
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
|
||||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||||
|
if resolved_owner_id is not None:
|
||||||
|
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
return await session.scalar(stmt) or 0
|
return await session.scalar(stmt) or 0
|
||||||
|
|
||||||
async def delete_by_thread(self, thread_id):
|
async def delete_by_thread(
|
||||||
|
self,
|
||||||
|
thread_id,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
):
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id)
|
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||||
|
if resolved_owner_id is not None:
|
||||||
|
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||||
|
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||||
count = await session.scalar(count_stmt) or 0
|
count = await session.scalar(count_stmt) or 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id))
|
await session.execute(delete(RunEventRow).where(*count_conditions))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return count
|
return count
|
||||||
|
|
||||||
async def delete_by_run(self, thread_id, run_id):
|
async def delete_by_run(
|
||||||
|
self,
|
||||||
|
thread_id,
|
||||||
|
run_id,
|
||||||
|
*,
|
||||||
|
owner_id: "str | None | _AutoSentinel" = AUTO,
|
||||||
|
):
|
||||||
|
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
count_stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||||
|
if resolved_owner_id is not None:
|
||||||
|
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
||||||
|
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||||
count = await session.scalar(count_stmt) or 0
|
count = await session.scalar(count_stmt) or 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
await session.execute(delete(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id))
|
await session.execute(delete(RunEventRow).where(*count_conditions))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
return count
|
return count
|
||||||
|
|||||||
@@ -0,0 +1,148 @@
|
|||||||
|
"""Request-scoped user context for owner-based authorization.
|
||||||
|
|
||||||
|
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
||||||
|
auth middleware sets after a successful authentication. Repository
|
||||||
|
methods read the contextvar via a sentinel default parameter, letting
|
||||||
|
routers stay free of ``owner_id`` boilerplate.
|
||||||
|
|
||||||
|
Three-state semantics for the repository ``owner_id`` parameter (the
|
||||||
|
consumer side of this module lives in ``deerflow.persistence.*``):
|
||||||
|
|
||||||
|
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
||||||
|
raise :class:`RuntimeError` if unset.
|
||||||
|
- Explicit ``str``: use the provided value, overriding contextvar.
|
||||||
|
- Explicit ``None``: no WHERE clause — used only by migration scripts
|
||||||
|
and admin CLIs that intentionally bypass isolation.
|
||||||
|
|
||||||
|
Dependency direction
|
||||||
|
--------------------
|
||||||
|
``persistence`` (lower layer) reads from this module; ``gateway.auth``
|
||||||
|
(higher layer) writes to it. ``CurrentUser`` is defined here as a
|
||||||
|
:class:`typing.Protocol` so that ``persistence`` never needs to import
|
||||||
|
the concrete ``User`` class from ``gateway.auth.models``. Any object
|
||||||
|
with an ``.id: str`` attribute structurally satisfies the protocol.
|
||||||
|
|
||||||
|
Asyncio semantics
|
||||||
|
-----------------
|
||||||
|
``ContextVar`` is task-local under asyncio, not thread-local. Each
|
||||||
|
FastAPI request runs in its own task, so the context is naturally
|
||||||
|
isolated. ``asyncio.create_task`` and ``asyncio.to_thread`` inherit the
|
||||||
|
parent task's context, which is typically the intended behaviour; if
|
||||||
|
a background task must *not* see the foreground user, wrap it with
|
||||||
|
``contextvars.copy_context()`` to get a clean copy.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from contextvars import ContextVar, Token
|
||||||
|
from typing import Final, Protocol, runtime_checkable
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class CurrentUser(Protocol):
|
||||||
|
"""Structural type for the current authenticated user.
|
||||||
|
|
||||||
|
Any object with an ``.id: str`` attribute satisfies this protocol.
|
||||||
|
Concrete implementations live in ``app.gateway.auth.models.User``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
|
||||||
|
|
||||||
|
_current_user: Final[ContextVar["CurrentUser | None"]] = ContextVar(
|
||||||
|
"deerflow_current_user", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_current_user(user: CurrentUser) -> Token[CurrentUser | None]:
|
||||||
|
"""Set the current user for this async task.
|
||||||
|
|
||||||
|
Returns a reset token that should be passed to
|
||||||
|
:func:`reset_current_user` in a ``finally`` block to restore the
|
||||||
|
previous context.
|
||||||
|
"""
|
||||||
|
return _current_user.set(user)
|
||||||
|
|
||||||
|
|
||||||
|
def reset_current_user(token: Token[CurrentUser | None]) -> None:
|
||||||
|
"""Restore the context to the state captured by ``token``."""
|
||||||
|
_current_user.reset(token)
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user() -> CurrentUser | None:
|
||||||
|
"""Return the current user, or ``None`` if unset.
|
||||||
|
|
||||||
|
Safe to call in any context. Used by code paths that can proceed
|
||||||
|
without a user (e.g. migration scripts, public endpoints).
|
||||||
|
"""
|
||||||
|
return _current_user.get()
|
||||||
|
|
||||||
|
|
||||||
|
def require_current_user() -> CurrentUser:
|
||||||
|
"""Return the current user, or raise :class:`RuntimeError`.
|
||||||
|
|
||||||
|
Used by repository code that must not be called outside a
|
||||||
|
request-authenticated context. The error message is phrased so
|
||||||
|
that a caller debugging a stack trace can locate the offending
|
||||||
|
code path.
|
||||||
|
"""
|
||||||
|
user = _current_user.get()
|
||||||
|
if user is None:
|
||||||
|
raise RuntimeError("repository accessed without user context")
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sentinel-based owner_id resolution
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
# Repository methods accept an ``owner_id`` keyword-only argument that
|
||||||
|
# defaults to ``AUTO``. The three possible values drive distinct
|
||||||
|
# behaviours; see the docstring on :func:`resolve_owner_id`.
|
||||||
|
|
||||||
|
|
||||||
|
class _AutoSentinel:
|
||||||
|
"""Singleton marker meaning 'resolve owner_id from contextvar'."""
|
||||||
|
|
||||||
|
_instance: "_AutoSentinel | None" = None
|
||||||
|
|
||||||
|
def __new__(cls) -> "_AutoSentinel":
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return "<AUTO>"
|
||||||
|
|
||||||
|
|
||||||
|
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_owner_id(
|
||||||
|
value: "str | None | _AutoSentinel",
|
||||||
|
*,
|
||||||
|
method_name: str = "repository method",
|
||||||
|
) -> str | None:
|
||||||
|
"""Resolve the owner_id parameter passed to a repository method.
|
||||||
|
|
||||||
|
Three-state semantics:
|
||||||
|
|
||||||
|
- :data:`AUTO` (default): read from contextvar; raise
|
||||||
|
:class:`RuntimeError` if no user is in context. This is the
|
||||||
|
common case for request-scoped calls.
|
||||||
|
- Explicit ``str``: use the provided id verbatim, overriding any
|
||||||
|
contextvar value. Useful for tests and admin-override flows.
|
||||||
|
- Explicit ``None``: no filter — the repository should skip the
|
||||||
|
owner_id WHERE clause entirely. Reserved for migration scripts
|
||||||
|
and CLI tools that intentionally bypass isolation.
|
||||||
|
"""
|
||||||
|
if isinstance(value, _AutoSentinel):
|
||||||
|
user = _current_user.get()
|
||||||
|
if user is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{method_name} called with owner_id=AUTO but no user context is set; "
|
||||||
|
"pass an explicit owner_id, set the contextvar via auth middleware, "
|
||||||
|
"or opt out with owner_id=None for migration/CLI paths."
|
||||||
|
)
|
||||||
|
return user.id
|
||||||
|
return value
|
||||||
@@ -30,6 +30,11 @@ postgres = [
|
|||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
|
dev = ["pytest>=8.0.0", "ruff>=0.14.11"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
markers = [
|
||||||
|
"no_auto_user: disable the conftest autouse contextvar fixture for this test",
|
||||||
|
]
|
||||||
|
|
||||||
[tool.uv.workspace]
|
[tool.uv.workspace]
|
||||||
members = ["packages/harness"]
|
members = ["packages/harness"]
|
||||||
|
|
||||||
|
|||||||
@@ -6,8 +6,11 @@ issues when unit-testing lightweight config/registry code in isolation.
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
# Make 'app' and 'deerflow' importable from any working directory
|
# Make 'app' and 'deerflow' importable from any working directory
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
@@ -31,3 +34,44 @@ _executor_mock.MAX_CONCURRENT_SUBAGENTS = 3
|
|||||||
_executor_mock.get_background_task_result = MagicMock()
|
_executor_mock.get_background_task_result = MagicMock()
|
||||||
|
|
||||||
sys.modules["deerflow.subagents.executor"] = _executor_mock
|
sys.modules["deerflow.subagents.executor"] = _executor_mock
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Auto-set user context for every test unless marked no_auto_user
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
#
|
||||||
|
# Repository methods read ``owner_id`` from a contextvar by default
|
||||||
|
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
||||||
|
# pre-existing persistence test would raise RuntimeError because the
|
||||||
|
# contextvar is unset. The fixture sets a default test user on every
|
||||||
|
# test; tests that explicitly want to verify behaviour *without* a user
|
||||||
|
# context should mark themselves ``@pytest.mark.no_auto_user``.
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _auto_user_context(request):
|
||||||
|
"""Inject a default ``test-user-autouse`` into the contextvar.
|
||||||
|
|
||||||
|
Opt-out via ``@pytest.mark.no_auto_user``. Uses lazy import so that
|
||||||
|
tests which don't touch the persistence layer never pay the cost
|
||||||
|
of importing runtime.user_context.
|
||||||
|
"""
|
||||||
|
if request.node.get_closest_marker("no_auto_user"):
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from deerflow.runtime.user_context import (
|
||||||
|
reset_current_user,
|
||||||
|
set_current_user,
|
||||||
|
)
|
||||||
|
except ImportError:
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
user = SimpleNamespace(id="test-user-autouse", email="test@local")
|
||||||
|
token = set_current_user(user)
|
||||||
|
try:
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
reset_current_user(token)
|
||||||
|
|||||||
@@ -104,7 +104,9 @@ class TestThreadMetaRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_check_access_no_owner_allows_all(self, tmp_path):
|
async def test_check_access_no_owner_allows_all(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
await repo.create("t1") # owner_id=None
|
# Explicit owner_id=None to bypass the new AUTO default that
|
||||||
|
# would otherwise pick up the test user from the autouse fixture.
|
||||||
|
await repo.create("t1", owner_id=None)
|
||||||
assert await repo.check_access("t1", "anyone") is True
|
assert await repo.check_access("t1", "anyone") is True
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,69 @@
|
|||||||
|
"""Tests for runtime.user_context — contextvar three-state semantics.
|
||||||
|
|
||||||
|
These tests opt out of the autouse contextvar fixture (added in
|
||||||
|
commit 6) because they explicitly test the cases where the contextvar
|
||||||
|
is set or unset.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from deerflow.runtime.user_context import (
|
||||||
|
CurrentUser,
|
||||||
|
get_current_user,
|
||||||
|
require_current_user,
|
||||||
|
reset_current_user,
|
||||||
|
set_current_user,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
def test_default_is_none():
|
||||||
|
"""Before any set, contextvar returns None."""
|
||||||
|
assert get_current_user() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
def test_set_and_reset_roundtrip():
|
||||||
|
"""set_current_user returns a token that reset restores."""
|
||||||
|
user = SimpleNamespace(id="user-1")
|
||||||
|
token = set_current_user(user)
|
||||||
|
try:
|
||||||
|
assert get_current_user() is user
|
||||||
|
finally:
|
||||||
|
reset_current_user(token)
|
||||||
|
assert get_current_user() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
def test_require_current_user_raises_when_unset():
|
||||||
|
"""require_current_user raises RuntimeError if contextvar is unset."""
|
||||||
|
assert get_current_user() is None
|
||||||
|
with pytest.raises(RuntimeError, match="without user context"):
|
||||||
|
require_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
def test_require_current_user_returns_user_when_set():
|
||||||
|
"""require_current_user returns the user when contextvar is set."""
|
||||||
|
user = SimpleNamespace(id="user-2")
|
||||||
|
token = set_current_user(user)
|
||||||
|
try:
|
||||||
|
assert require_current_user() is user
|
||||||
|
finally:
|
||||||
|
reset_current_user(token)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
def test_protocol_accepts_duck_typed():
|
||||||
|
"""CurrentUser is a runtime_checkable Protocol matching any .id-bearing object."""
|
||||||
|
user = SimpleNamespace(id="user-3")
|
||||||
|
assert isinstance(user, CurrentUser)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.no_auto_user
|
||||||
|
def test_protocol_rejects_no_id():
|
||||||
|
"""Objects without .id do not satisfy CurrentUser Protocol."""
|
||||||
|
not_a_user = SimpleNamespace(email="no-id@example.com")
|
||||||
|
assert not isinstance(not_a_user, CurrentUser)
|
||||||
Reference in New Issue
Block a user