refactor(persistence): rename owner_id to user_id and thread_meta_repo to thread_store

Rename owner_id to user_id across all persistence models, repositories,
stores, routers, and tests for clearer semantics. Rename thread_meta_repo
to thread_store for consistency with run_store/run_event_store naming.
Add ThreadMetaStore return type annotation to get_thread_store().

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-10 15:05:10 +08:00
parent 03952eca53
commit 8da1903168
32 changed files with 256 additions and 276 deletions
@@ -16,7 +16,7 @@ class FeedbackRow(Base):
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
message_id: Mapped[str | None] = mapped_column(String(64))
# message_id is an optional RunEventStore event identifier —
# allows feedback to target a specific message or the entire run
@@ -12,7 +12,7 @@ from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.feedback.model import FeedbackRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
class FeedbackRepository:
@@ -33,19 +33,19 @@ class FeedbackRepository:
run_id: str,
thread_id: str,
rating: int,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
message_id: str | None = None,
comment: str | None = None,
) -> dict:
"""Create a feedback record. rating must be +1 or -1."""
if rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {rating}")
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create")
row = FeedbackRow(
feedback_id=str(uuid.uuid4()),
run_id=run_id,
thread_id=thread_id,
owner_id=resolved_owner_id,
user_id=resolved_user_id,
message_id=message_id,
rating=rating,
comment=comment,
@@ -61,14 +61,14 @@ class FeedbackRepository:
self,
feedback_id: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
) -> dict | None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get")
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
if row is None:
return None
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
if resolved_user_id is not None and row.user_id != resolved_user_id:
return None
return self._row_to_dict(row)
@@ -78,12 +78,12 @@ class FeedbackRepository:
run_id: str,
*,
limit: int = 100,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
if resolved_owner_id is not None:
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
if resolved_user_id is not None:
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
@@ -94,12 +94,12 @@ class FeedbackRepository:
thread_id: str,
*,
limit: int = 100,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread")
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
if resolved_owner_id is not None:
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
if resolved_user_id is not None:
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
@@ -109,14 +109,14 @@ class FeedbackRepository:
self,
feedback_id: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
) -> bool:
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete")
async with self._sf() as session:
row = await session.get(FeedbackRow, feedback_id)
if row is None:
return False
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
if resolved_user_id is not None and row.user_id != resolved_user_id:
return False
await session.delete(row)
await session.commit()
@@ -19,7 +19,7 @@ class RunEventRow(Base):
# Owner of the conversation this event belongs to. Nullable for data
# created before auth was introduced; populated by auth middleware on
# new writes and by the boot-time orphan migration on existing rows.
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
category: Mapped[str] = mapped_column(String(16), nullable=False)
# "message" | "trace" | "lifecycle"
@@ -16,7 +16,7 @@ class RunRow(Base):
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
assistant_id: Mapped[str | None] = mapped_column(String(128))
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
status: Mapped[str] = mapped_column(String(20), default="pending")
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
@@ -16,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.run.model import RunRow
from deerflow.runtime.runs.store.base import RunStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
class RunRepository(RunStore):
@@ -69,7 +69,7 @@ class RunRepository(RunStore):
*,
thread_id,
assistant_id=None,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
status="pending",
multitask_strategy="reject",
metadata=None,
@@ -78,13 +78,13 @@ class RunRepository(RunStore):
created_at=None,
follow_up_to_run_id=None,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
now = datetime.now(UTC)
row = RunRow(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
owner_id=resolved_owner_id,
user_id=resolved_user_id,
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
@@ -102,14 +102,14 @@ class RunRepository(RunStore):
self,
run_id,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get")
async with self._sf() as session:
row = await session.get(RunRow, run_id)
if row is None:
return None
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
if resolved_user_id is not None and row.user_id != resolved_user_id:
return None
return self._row_to_dict(row)
@@ -117,13 +117,13 @@ class RunRepository(RunStore):
self,
thread_id,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
limit=100,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread")
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
if resolved_owner_id is not None:
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
if resolved_user_id is not None:
stmt = stmt.where(RunRow.user_id == resolved_user_id)
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
async with self._sf() as session:
result = await session.execute(stmt)
@@ -141,14 +141,14 @@ class RunRepository(RunStore):
self,
run_id,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete")
async with self._sf() as session:
row = await session.get(RunRow, run_id)
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
if resolved_user_id is not None and row.user_id != resolved_user_id:
return
await session.delete(row)
await session.commit()
@@ -15,7 +15,7 @@ class ThreadMetaRow(Base):
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
display_name: Mapped[str | None] = mapped_column(String(256))
status: Mapped[str] = mapped_column(String(20), default="idle")
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.thread_meta.base import ThreadMetaStore
from deerflow.persistence.thread_meta.model import ThreadMetaRow
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
class ThreadMetaRepository(ThreadMetaStore):
@@ -32,18 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
thread_id: str,
*,
assistant_id: str | None = None,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
display_name: str | None = None,
metadata: dict | None = None,
) -> dict:
# Auto-resolve owner_id from contextvar when AUTO; explicit None
# Auto-resolve user_id from contextvar when AUTO; explicit None
# creates an orphan row (used by migration scripts).
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create")
now = datetime.now(UTC)
row = ThreadMetaRow(
thread_id=thread_id,
assistant_id=assistant_id,
owner_id=resolved_owner_id,
user_id=resolved_user_id,
display_name=display_name,
metadata_json=metadata or {},
created_at=now,
@@ -59,40 +59,34 @@ class ThreadMetaRepository(ThreadMetaStore):
self,
thread_id: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
) -> dict | None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return None
# Enforce owner filter unless explicitly bypassed (owner_id=None).
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
# Enforce owner filter unless explicitly bypassed (user_id=None).
if resolved_user_id is not None and row.user_id != resolved_user_id:
return None
return self._row_to_dict(row)
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
"""Check if ``owner_id`` has access to ``thread_id``.
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
"""Check if ``user_id`` has access to ``thread_id``.
Two modes — one row, two distinct semantics depending on what
the caller is about to do:
- ``require_existing=False`` (default, permissive):
Returns True for: row missing (untracked legacy thread),
``row.owner_id`` is None (shared / pre-auth data),
or ``row.owner_id == owner_id``. Use for **read-style**
``row.user_id`` is None (shared / pre-auth data),
or ``row.user_id == user_id``. Use for **read-style**
decorators where treating an untracked thread as accessible
preserves backward-compat.
- ``require_existing=True`` (strict):
Returns True **only** when the row exists AND
(``row.owner_id == owner_id`` OR ``row.owner_id is None``).
(``row.user_id == user_id`` OR ``row.user_id is None``).
Use for **destructive / mutating** decorators (DELETE, PATCH,
state-update) so a thread that has *already been deleted*
cannot be re-targeted by any caller — closing the
@@ -103,9 +97,9 @@ class ThreadMetaRepository(ThreadMetaStore):
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return not require_existing
if row.owner_id is None:
if row.user_id is None:
return True
return row.owner_id == owner_id
return row.user_id == user_id
async def search(
self,
@@ -114,17 +108,17 @@ class ThreadMetaRepository(ThreadMetaStore):
status: str | None = None,
limit: int = 100,
offset: int = 0,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
) -> list[dict]:
"""Search threads with optional metadata and status filters.
Owner filter is enforced by default: caller must be in a user
context. Pass ``owner_id=None`` to bypass (migration/CLI).
context. Pass ``user_id=None`` to bypass (migration/CLI).
"""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
if resolved_owner_id is not None:
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
if resolved_user_id is not None:
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
if status:
stmt = stmt.where(ThreadMetaRow.status == status)
@@ -144,24 +138,24 @@ class ThreadMetaRepository(ThreadMetaStore):
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
"""Return True if the row exists and is owned (or filter bypassed)."""
if resolved_owner_id is None:
if resolved_user_id is None:
return True # explicit bypass
row = await session.get(ThreadMetaRow, thread_id)
return row is not None and row.owner_id == resolved_owner_id
return row is not None and row.user_id == resolved_user_id
async def update_display_name(
self,
thread_id: str,
display_name: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
"""Update the display_name (title) for a thread."""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_owner_id):
if not await self._check_ownership(session, thread_id, resolved_user_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
await session.commit()
@@ -171,11 +165,11 @@ class ThreadMetaRepository(ThreadMetaStore):
thread_id: str,
status: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_owner_id):
if not await self._check_ownership(session, thread_id, resolved_user_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
await session.commit()
@@ -185,20 +179,20 @@ class ThreadMetaRepository(ThreadMetaStore):
thread_id: str,
metadata: dict,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
"""Merge ``metadata`` into ``metadata_json``.
Read-modify-write inside a single session/transaction so concurrent
callers see consistent state. No-op if the row does not exist or
the owner_id check fails.
the user_id check fails.
"""
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
if resolved_user_id is not None and row.user_id != resolved_user_id:
return
merged = dict(row.metadata_json or {})
merged.update(metadata)
@@ -210,14 +204,14 @@ class ThreadMetaRepository(ThreadMetaStore):
self,
thread_id: str,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete")
async with self._sf() as session:
row = await session.get(ThreadMetaRow, thread_id)
if row is None:
return
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
if resolved_user_id is not None and row.user_id != resolved_user_id:
return
await session.delete(row)
await session.commit()
@@ -24,12 +24,12 @@ from collections.abc import AsyncIterator
from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config
from deerflow.runtime.checkpointer.provider import (
POSTGRES_CONN_REQUIRED,
POSTGRES_INSTALL,
SQLITE_INSTALL,
)
from deerflow.config.app_config import get_app_config
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__)
@@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
logger = logging.getLogger(__name__)
@@ -55,8 +55,8 @@ class DbRunEventStore(RunEventStore):
return content, metadata or {}
@staticmethod
def _owner_from_context() -> str | None:
"""Soft read of owner_id from contextvar for write paths.
def _user_id_from_context() -> str | None:
"""Soft read of user_id from contextvar for write paths.
Returns ``None`` (no filter / no stamp) if contextvar is unset,
which is the expected case for background worker writes. HTTP
@@ -81,7 +81,7 @@ class DbRunEventStore(RunEventStore):
metadata = {**(metadata or {}), "content_is_dict": True}
else:
db_content = content
owner_id = self._owner_from_context()
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
# Use FOR UPDATE to serialize seq assignment within a thread.
@@ -92,7 +92,7 @@ class DbRunEventStore(RunEventStore):
row = RunEventRow(
thread_id=thread_id,
run_id=run_id,
owner_id=owner_id,
user_id=user_id,
event_type=event_type,
category=category,
content=db_content,
@@ -106,7 +106,7 @@ class DbRunEventStore(RunEventStore):
async def put_batch(self, events):
if not events:
return []
owner_id = self._owner_from_context()
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread).
@@ -130,7 +130,7 @@ class DbRunEventStore(RunEventStore):
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
owner_id=e.get("owner_id", owner_id),
user_id=e.get("user_id", user_id),
event_type=e["event_type"],
category=category,
content=db_content,
@@ -149,12 +149,12 @@ class DbRunEventStore(RunEventStore):
limit=50,
before_seq=None,
after_seq=None,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
if before_seq is not None:
stmt = stmt.where(RunEventRow.seq < before_seq)
if after_seq is not None:
@@ -181,12 +181,12 @@ class DbRunEventStore(RunEventStore):
*,
event_types=None,
limit=500,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
if event_types:
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
@@ -199,12 +199,12 @@ class DbRunEventStore(RunEventStore):
thread_id,
run_id,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
stmt = stmt.order_by(RunEventRow.seq.asc())
async with self._sf() as session:
result = await session.execute(stmt)
@@ -214,12 +214,12 @@ class DbRunEventStore(RunEventStore):
self,
thread_id,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
if resolved_owner_id is not None:
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
if resolved_user_id is not None:
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
async with self._sf() as session:
return await session.scalar(stmt) or 0
@@ -227,13 +227,13 @@ class DbRunEventStore(RunEventStore):
self,
thread_id,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
async with self._sf() as session:
count_conditions = [RunEventRow.thread_id == thread_id]
if resolved_owner_id is not None:
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
if resolved_user_id is not None:
count_conditions.append(RunEventRow.user_id == resolved_user_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0
if count > 0:
@@ -246,13 +246,13 @@ class DbRunEventStore(RunEventStore):
thread_id,
run_id,
*,
owner_id: str | None | _AutoSentinel = AUTO,
user_id: str | None | _AutoSentinel = AUTO,
):
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
async with self._sf() as session:
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
if resolved_owner_id is not None:
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
if resolved_user_id is not None:
count_conditions.append(RunEventRow.user_id == resolved_user_id)
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
count = await session.scalar(count_stmt) or 0
if count > 0:
@@ -4,8 +4,8 @@ RunManager depends on this interface. Implementations:
- MemoryRunStore: in-memory dict (development, tests)
- Future: RunRepository backed by SQLAlchemy ORM
All methods accept an optional owner_id for user isolation.
When owner_id is None, no user filtering is applied (single-user mode).
All methods accept an optional user_id for user isolation.
When user_id is None, no user filtering is applied (single-user mode).
"""
from __future__ import annotations
@@ -22,7 +22,7 @@ class RunStore(abc.ABC):
*,
thread_id: str,
assistant_id: str | None = None,
owner_id: str | None = None,
user_id: str | None = None,
status: str = "pending",
multitask_strategy: str = "reject",
metadata: dict[str, Any] | None = None,
@@ -42,7 +42,7 @@ class RunStore(abc.ABC):
self,
thread_id: str,
*,
owner_id: str | None = None,
user_id: str | None = None,
limit: int = 100,
) -> list[dict[str, Any]]:
pass
@@ -21,7 +21,7 @@ class MemoryRunStore(RunStore):
*,
thread_id,
assistant_id=None,
owner_id=None,
user_id=None,
status="pending",
multitask_strategy="reject",
metadata=None,
@@ -35,7 +35,7 @@ class MemoryRunStore(RunStore):
"run_id": run_id,
"thread_id": thread_id,
"assistant_id": assistant_id,
"owner_id": owner_id,
"user_id": user_id,
"status": status,
"multitask_strategy": multitask_strategy,
"metadata": metadata or {},
@@ -49,8 +49,8 @@ class MemoryRunStore(RunStore):
async def get(self, run_id):
return self._runs.get(run_id)
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (owner_id is None or r.get("owner_id") == owner_id)]
async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)]
results.sort(key=lambda r: r["created_at"], reverse=True)
return results[:limit]
@@ -50,7 +50,7 @@ class RunContext:
store: Any | None = field(default=None)
event_store: Any | None = field(default=None)
run_events_config: Any | None = field(default=None)
thread_meta_repo: Any | None = field(default=None)
thread_store: Any | None = field(default=None)
follow_up_to_run_id: str | None = field(default=None)
@@ -75,7 +75,7 @@ async def run_agent(
store = ctx.store
event_store = ctx.event_store
run_events_config = ctx.run_events_config
thread_meta_repo = ctx.thread_meta_repo
thread_store = ctx.thread_store
follow_up_to_run_id = ctx.follow_up_to_run_id
run_id = record.run_id
@@ -376,14 +376,14 @@ async def run_agent(
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
title = ckpt.get("channel_values", {}).get("title")
if title:
await thread_meta_repo.update_display_name(thread_id, title)
await thread_store.update_display_name(thread_id, title)
except Exception:
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
# Update threads_meta status based on run outcome
try:
final_status = "idle" if record.status == RunStatus.success else record.status.value
await thread_meta_repo.update_status(thread_id, final_status)
await thread_store.update_status(thread_id, final_status)
except Exception:
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
@@ -1,11 +1,11 @@
"""Request-scoped user context for owner-based authorization.
"""Request-scoped user context for user-based authorization.
This module holds a :class:`~contextvars.ContextVar` that the gateway's
auth middleware sets after a successful authentication. Repository
methods read the contextvar via a sentinel default parameter, letting
routers stay free of ``owner_id`` boilerplate.
routers stay free of ``user_id`` boilerplate.
Three-state semantics for the repository ``owner_id`` parameter (the
Three-state semantics for the repository ``user_id`` parameter (the
consumer side of this module lives in ``deerflow.persistence.*``):
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
@@ -91,16 +91,16 @@ def require_current_user() -> CurrentUser:
# ---------------------------------------------------------------------------
# Sentinel-based owner_id resolution
# Sentinel-based user_id resolution
# ---------------------------------------------------------------------------
#
# Repository methods accept an ``owner_id`` keyword-only argument that
# Repository methods accept a ``user_id`` keyword-only argument that
# defaults to ``AUTO``. The three possible values drive distinct
# behaviours; see the docstring on :func:`resolve_owner_id`.
# behaviours; see the docstring on :func:`resolve_user_id`.
class _AutoSentinel:
"""Singleton marker meaning 'resolve owner_id from contextvar'."""
"""Singleton marker meaning 'resolve user_id from contextvar'."""
_instance: _AutoSentinel | None = None
@@ -116,12 +116,12 @@ class _AutoSentinel:
AUTO: Final[_AutoSentinel] = _AutoSentinel()
def resolve_owner_id(
def resolve_user_id(
value: str | None | _AutoSentinel,
*,
method_name: str = "repository method",
) -> str | None:
"""Resolve the owner_id parameter passed to a repository method.
"""Resolve the user_id parameter passed to a repository method.
Three-state semantics:
@@ -131,16 +131,16 @@ def resolve_owner_id(
- Explicit ``str``: use the provided id verbatim, overriding any
contextvar value. Useful for tests and admin-override flows.
- Explicit ``None``: no filter — the repository should skip the
owner_id WHERE clause entirely. Reserved for migration scripts
user_id WHERE clause entirely. Reserved for migration scripts
and CLI tools that intentionally bypass isolation.
"""
if isinstance(value, _AutoSentinel):
user = _current_user.get()
if user is None:
raise RuntimeError(f"{method_name} called with owner_id=AUTO but no user context is set; pass an explicit owner_id, set the contextvar via auth middleware, or opt out with owner_id=None for migration/CLI paths.")
raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.")
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
# ``UUID`` for the API surface, but the persistence layer
# stores ``owner_id`` as ``String(64)`` and aiosqlite cannot
# stores ``user_id`` as ``String(64)`` and aiosqlite cannot
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
# not supported"). Honour the documented return type here
# rather than ripple a type change through every caller.