mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
refactor(persistence): rename owner_id to user_id and thread_meta_repo to thread_store
Rename owner_id to user_id across all persistence models, repositories, stores, routers, and tests for clearer semantics. Rename thread_meta_repo to thread_store for consistency with run_store/run_event_store naming. Add ThreadMetaStore return type annotation to get_thread_store(). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -42,6 +42,11 @@ logger = logging.getLogger(__name__)
|
|||||||
async def _ensure_admin_user(app: FastAPI) -> None:
|
async def _ensure_admin_user(app: FastAPI) -> None:
|
||||||
"""Startup hook: generate init token on first boot; migrate orphan threads otherwise.
|
"""Startup hook: generate init token on first boot; migrate orphan threads otherwise.
|
||||||
|
|
||||||
|
After admin creation, migrate orphan threads from the LangGraph
|
||||||
|
store (metadata.user_id unset) to the admin account. This is the
|
||||||
|
"no-auth → with-auth" upgrade path: users who ran DeerFlow without
|
||||||
|
authentication have existing LangGraph thread data that needs an
|
||||||
|
owner assigned.
|
||||||
First boot (no admin exists):
|
First boot (no admin exists):
|
||||||
- Generates a one-time ``init_token`` stored in ``app.state.init_token``
|
- Generates a one-time ``init_token`` stored in ``app.state.init_token``
|
||||||
- Logs the token to stdout so the operator can copy-paste it into the
|
- Logs the token to stdout so the operator can copy-paste it into the
|
||||||
@@ -52,7 +57,7 @@ async def _ensure_admin_user(app: FastAPI) -> None:
|
|||||||
- Runs the one-time "no-auth → with-auth" orphan thread migration for
|
- Runs the one-time "no-auth → with-auth" orphan thread migration for
|
||||||
existing LangGraph thread metadata that has no owner_id.
|
existing LangGraph thread metadata that has no owner_id.
|
||||||
|
|
||||||
No SQL persistence migration is needed: the four owner_id columns
|
No SQL persistence migration is needed: the four user_id columns
|
||||||
(threads_meta, runs, run_events, feedback) only come into existence
|
(threads_meta, runs, run_events, feedback) only come into existence
|
||||||
alongside the auth module via create_all, so freshly created tables
|
alongside the auth module via create_all, so freshly created tables
|
||||||
never contain NULL-owner rows.
|
never contain NULL-owner rows.
|
||||||
@@ -96,6 +101,8 @@ async def _ensure_admin_user(app: FastAPI) -> None:
|
|||||||
admin_id = str(row.id)
|
admin_id = str(row.id)
|
||||||
|
|
||||||
# LangGraph store orphan migration — non-fatal.
|
# LangGraph store orphan migration — non-fatal.
|
||||||
|
# This covers the "no-auth → with-auth" upgrade path for users
|
||||||
|
# whose existing LangGraph thread metadata has no user_id set.
|
||||||
store = getattr(app.state, "store", None)
|
store = getattr(app.state, "store", None)
|
||||||
if store is not None:
|
if store is not None:
|
||||||
try:
|
try:
|
||||||
@@ -127,7 +134,7 @@ async def _iter_store_items(store, namespace, *, page_size: int = 500):
|
|||||||
|
|
||||||
|
|
||||||
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
||||||
"""Migrate LangGraph store threads with no owner_id to the given admin.
|
"""Migrate LangGraph store threads with no user_id to the given admin.
|
||||||
|
|
||||||
Uses cursor pagination so all orphans are migrated regardless of
|
Uses cursor pagination so all orphans are migrated regardless of
|
||||||
count. Returns the number of rows migrated.
|
count. Returns the number of rows migrated.
|
||||||
@@ -135,8 +142,8 @@ async def _migrate_orphaned_threads(store, admin_user_id: str) -> int:
|
|||||||
migrated = 0
|
migrated = 0
|
||||||
async for item in _iter_store_items(store, ("threads",)):
|
async for item in _iter_store_items(store, ("threads",)):
|
||||||
metadata = item.value.get("metadata", {})
|
metadata = item.value.get("metadata", {})
|
||||||
if not metadata.get("owner_id"):
|
if not metadata.get("user_id"):
|
||||||
metadata["owner_id"] = admin_user_id
|
metadata["user_id"] = admin_user_id
|
||||||
item.value["metadata"] = metadata
|
item.value["metadata"] = metadata
|
||||||
await store.aput(("threads",), item.key, item.value)
|
await store.aput(("threads",), item.key, item.value)
|
||||||
migrated += 1
|
migrated += 1
|
||||||
|
|||||||
@@ -233,18 +233,18 @@ def require_permission(
|
|||||||
# (``threads_meta`` table). We verify ownership via
|
# (``threads_meta`` table). We verify ownership via
|
||||||
# ``ThreadMetaStore.check_access``: it returns True for
|
# ``ThreadMetaStore.check_access``: it returns True for
|
||||||
# missing rows (untracked legacy thread) and for rows whose
|
# missing rows (untracked legacy thread) and for rows whose
|
||||||
# ``owner_id`` is NULL (shared / pre-auth data), so this is
|
# ``user_id`` is NULL (shared / pre-auth data), so this is
|
||||||
# strict-deny rather than strict-allow — only an *existing*
|
# strict-deny rather than strict-allow — only an *existing*
|
||||||
# row with a *different* owner_id triggers 404.
|
# row with a *different* user_id triggers 404.
|
||||||
if owner_check:
|
if owner_check:
|
||||||
thread_id = kwargs.get("thread_id")
|
thread_id = kwargs.get("thread_id")
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||||
|
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
allowed = await thread_meta_repo.check_access(
|
allowed = await thread_store.check_access(
|
||||||
thread_id,
|
thread_id,
|
||||||
str(auth.user.id),
|
str(auth.user.id),
|
||||||
require_existing=require_existing,
|
require_existing=require_existing,
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Centralized accessors for singleton objects stored on ``app.state``.
|
"""Centralized accessors for singleton objects stored on ``app.state``.
|
||||||
|
|
||||||
**Getters** (used by routers): raise 503 when a required dependency is
|
**Getters** (used by routers): raise 503 when a required dependency is
|
||||||
missing, except ``get_store`` and ``get_thread_meta_repo`` which return
|
missing, except ``get_store`` which returns ``None``.
|
||||||
``None``.
|
|
||||||
|
|
||||||
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
Initialization is handled directly in ``app.py`` via :class:`AsyncExitStack`.
|
||||||
"""
|
"""
|
||||||
@@ -20,6 +19,7 @@ from deerflow.runtime import RunContext, RunManager
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.gateway.auth.local_provider import LocalAuthProvider
|
from app.gateway.auth.local_provider import LocalAuthProvider
|
||||||
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
from app.gateway.auth.repositories.sqlite import SQLiteUserRepository
|
||||||
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
@@ -31,10 +31,10 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
async with langgraph_runtime(app):
|
async with langgraph_runtime(app):
|
||||||
yield
|
yield
|
||||||
"""
|
"""
|
||||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
|
||||||
from deerflow.config import get_app_config
|
from deerflow.config import get_app_config
|
||||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine_from_config
|
||||||
from deerflow.runtime import make_store, make_stream_bridge
|
from deerflow.runtime import make_store, make_stream_bridge
|
||||||
|
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||||
from deerflow.runtime.events.store import make_run_event_store
|
from deerflow.runtime.events.store import make_run_event_store
|
||||||
|
|
||||||
async with AsyncExitStack() as stack:
|
async with AsyncExitStack() as stack:
|
||||||
@@ -53,18 +53,18 @@ async def langgraph_runtime(app: FastAPI) -> AsyncGenerator[None, None]:
|
|||||||
if sf is not None:
|
if sf is not None:
|
||||||
from deerflow.persistence.feedback import FeedbackRepository
|
from deerflow.persistence.feedback import FeedbackRepository
|
||||||
from deerflow.persistence.run import RunRepository
|
from deerflow.persistence.run import RunRepository
|
||||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
|
||||||
|
|
||||||
app.state.run_store = RunRepository(sf)
|
app.state.run_store = RunRepository(sf)
|
||||||
app.state.feedback_repo = FeedbackRepository(sf)
|
app.state.feedback_repo = FeedbackRepository(sf)
|
||||||
app.state.thread_meta_repo = ThreadMetaRepository(sf)
|
|
||||||
else:
|
else:
|
||||||
from deerflow.persistence.thread_meta import MemoryThreadMetaStore
|
|
||||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||||
|
|
||||||
app.state.run_store = MemoryRunStore()
|
app.state.run_store = MemoryRunStore()
|
||||||
app.state.feedback_repo = None
|
app.state.feedback_repo = None
|
||||||
app.state.thread_meta_repo = MemoryThreadMetaStore(app.state.store)
|
|
||||||
|
from deerflow.persistence.thread_meta import make_thread_store
|
||||||
|
|
||||||
|
app.state.thread_store = make_thread_store(sf, app.state.store)
|
||||||
|
|
||||||
# Run event store (has its own factory with config-driven backend selection)
|
# Run event store (has its own factory with config-driven backend selection)
|
||||||
run_events_config = getattr(config, "run_events", None)
|
run_events_config = getattr(config, "run_events", None)
|
||||||
@@ -110,7 +110,12 @@ def get_store(request: Request):
|
|||||||
return getattr(request.app.state, "store", None)
|
return getattr(request.app.state, "store", None)
|
||||||
|
|
||||||
|
|
||||||
get_thread_meta_repo = _require("thread_meta_repo", "Thread metadata store")
|
def get_thread_store(request: Request) -> ThreadMetaStore:
|
||||||
|
"""Return the thread metadata store (SQL or memory-backed)."""
|
||||||
|
val = getattr(request.app.state, "thread_store", None)
|
||||||
|
if val is None:
|
||||||
|
raise HTTPException(status_code=503, detail="Thread metadata store not available")
|
||||||
|
return val
|
||||||
|
|
||||||
|
|
||||||
def get_run_context(request: Request) -> RunContext:
|
def get_run_context(request: Request) -> RunContext:
|
||||||
@@ -128,7 +133,7 @@ def get_run_context(request: Request) -> RunContext:
|
|||||||
store=get_store(request),
|
store=get_store(request),
|
||||||
event_store=get_run_event_store(request),
|
event_store=get_run_event_store(request),
|
||||||
run_events_config=getattr(get_app_config(), "run_events", None),
|
run_events_config=getattr(get_app_config(), "run_events", None),
|
||||||
thread_meta_repo=get_thread_meta_repo(request),
|
thread_store=get_thread_store(request),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -93,14 +93,14 @@ async def authenticate(request):
|
|||||||
|
|
||||||
@auth.on
|
@auth.on
|
||||||
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
async def add_owner_filter(ctx: Auth.types.AuthContext, value: dict):
|
||||||
"""Inject owner_id metadata on writes; filter by owner_id on reads.
|
"""Inject user_id metadata on writes; filter by user_id on reads.
|
||||||
|
|
||||||
Gateway stores thread ownership as ``metadata.owner_id``.
|
Gateway stores thread ownership as ``metadata.user_id``.
|
||||||
This handler ensures LangGraph Server enforces the same isolation.
|
This handler ensures LangGraph Server enforces the same isolation.
|
||||||
"""
|
"""
|
||||||
# On create/update: stamp owner_id into metadata
|
# On create/update: stamp user_id into metadata
|
||||||
metadata = value.setdefault("metadata", {})
|
metadata = value.setdefault("metadata", {})
|
||||||
metadata["owner_id"] = ctx.user.identity
|
metadata["user_id"] = ctx.user.identity
|
||||||
|
|
||||||
# Return filter dict — LangGraph applies it to search/read/delete
|
# Return filter dict — LangGraph applies it to search/read/delete
|
||||||
return {"owner_id": ctx.user.identity}
|
return {"user_id": ctx.user.identity}
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ class FeedbackResponse(BaseModel):
|
|||||||
feedback_id: str
|
feedback_id: str
|
||||||
run_id: str
|
run_id: str
|
||||||
thread_id: str
|
thread_id: str
|
||||||
owner_id: str | None = None
|
user_id: str | None = None
|
||||||
message_id: str | None = None
|
message_id: str | None = None
|
||||||
rating: int
|
rating: int
|
||||||
comment: str | None = None
|
comment: str | None = None
|
||||||
@@ -80,7 +80,7 @@ async def create_feedback(
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
rating=body.rating,
|
rating=body.rating,
|
||||||
owner_id=user_id,
|
user_id=user_id,
|
||||||
message_id=body.message_id,
|
message_id=body.message_id,
|
||||||
comment=body.comment,
|
comment=body.comment,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ router = APIRouter(prefix="/api/threads", tags=["threads"])
|
|||||||
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
# them. Pydantic ``@field_validator("metadata")`` strips them on every
|
||||||
# inbound model below so a malicious client cannot reflect a forged
|
# inbound model below so a malicious client cannot reflect a forged
|
||||||
# owner identity through the API surface. Defense-in-depth — the
|
# owner identity through the API surface. Defense-in-depth — the
|
||||||
# row-level invariant is still ``threads_meta.owner_id`` populated from
|
# row-level invariant is still ``threads_meta.user_id`` populated from
|
||||||
# the auth contextvar; this list closes the metadata-blob echo gap.
|
# the auth contextvar; this list closes the metadata-blob echo gap.
|
||||||
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
|
_SERVER_RESERVED_METADATA_KEYS: frozenset[str] = frozenset({"owner_id", "user_id"})
|
||||||
|
|
||||||
@@ -194,7 +194,7 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
|||||||
and removes the thread_meta row from the configured ThreadMetaStore
|
and removes the thread_meta row from the configured ThreadMetaStore
|
||||||
(sqlite or memory).
|
(sqlite or memory).
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
# Clean local filesystem
|
# Clean local filesystem
|
||||||
response = _delete_thread_data(thread_id)
|
response = _delete_thread_data(thread_id)
|
||||||
@@ -211,8 +211,8 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
|
|||||||
# Remove thread_meta row (best-effort) — required for sqlite backend
|
# Remove thread_meta row (best-effort) — required for sqlite backend
|
||||||
# so the deleted thread no longer appears in /threads/search.
|
# so the deleted thread no longer appears in /threads/search.
|
||||||
try:
|
try:
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
await thread_meta_repo.delete(thread_id)
|
await thread_store.delete(thread_id)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
logger.debug("Could not delete thread_meta for %s (not critical)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
@@ -227,17 +227,17 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
and an empty checkpoint (so state endpoints work immediately).
|
and an empty checkpoint (so state endpoints work immediately).
|
||||||
Idempotent: returns the existing record when ``thread_id`` already exists.
|
Idempotent: returns the existing record when ``thread_id`` already exists.
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
checkpointer = get_checkpointer(request)
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
thread_id = body.thread_id or str(uuid.uuid4())
|
thread_id = body.thread_id or str(uuid.uuid4())
|
||||||
now = time.time()
|
now = time.time()
|
||||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||||
|
|
||||||
# Idempotency: return existing record when already present
|
# Idempotency: return existing record when already present
|
||||||
existing_record = await thread_meta_repo.get(thread_id)
|
existing_record = await thread_store.get(thread_id)
|
||||||
if existing_record is not None:
|
if existing_record is not None:
|
||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -249,7 +249,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
|
|
||||||
# Write thread_meta so the thread appears in /threads/search immediately
|
# Write thread_meta so the thread appears in /threads/search immediately
|
||||||
try:
|
try:
|
||||||
await thread_meta_repo.create(
|
await thread_store.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=getattr(body, "assistant_id", None),
|
assistant_id=getattr(body, "assistant_id", None),
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
@@ -293,9 +293,9 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
|||||||
Delegates to the configured ThreadMetaStore implementation
|
Delegates to the configured ThreadMetaStore implementation
|
||||||
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
(SQL-backed for sqlite/postgres, Store-backed for memory mode).
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
repo = get_thread_meta_repo(request)
|
repo = get_thread_store(request)
|
||||||
rows = await repo.search(
|
rows = await repo.search(
|
||||||
metadata=body.metadata or None,
|
metadata=body.metadata or None,
|
||||||
status=body.status,
|
status=body.status,
|
||||||
@@ -320,22 +320,22 @@ async def search_threads(body: ThreadSearchRequest, request: Request) -> list[Th
|
|||||||
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
@require_permission("threads", "write", owner_check=True, require_existing=True)
|
||||||
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Request) -> ThreadResponse:
|
||||||
"""Merge metadata into a thread record."""
|
"""Merge metadata into a thread record."""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
record = await thread_meta_repo.get(thread_id)
|
record = await thread_store.get(thread_id)
|
||||||
if record is None:
|
if record is None:
|
||||||
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
raise HTTPException(status_code=404, detail=f"Thread {thread_id} not found")
|
||||||
|
|
||||||
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
|
# ``body.metadata`` already stripped by ``ThreadPatchRequest._strip_reserved``.
|
||||||
try:
|
try:
|
||||||
await thread_meta_repo.update_metadata(thread_id, body.metadata)
|
await thread_store.update_metadata(thread_id, body.metadata)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
logger.exception("Failed to patch thread %s", sanitize_log_param(thread_id))
|
||||||
raise HTTPException(status_code=500, detail="Failed to update thread")
|
raise HTTPException(status_code=500, detail="Failed to update thread")
|
||||||
|
|
||||||
# Re-read to get the merged metadata + refreshed updated_at
|
# Re-read to get the merged metadata + refreshed updated_at
|
||||||
record = await thread_meta_repo.get(thread_id) or record
|
record = await thread_store.get(thread_id) or record
|
||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
status=record.get("status", "idle"),
|
status=record.get("status", "idle"),
|
||||||
@@ -354,12 +354,12 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
|
|||||||
execution status from the checkpointer. Falls back to the checkpointer
|
execution status from the checkpointer. Falls back to the checkpointer
|
||||||
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
alone for threads that pre-date ThreadMetaStore adoption (backward compat).
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
checkpointer = get_checkpointer(request)
|
checkpointer = get_checkpointer(request)
|
||||||
|
|
||||||
record: dict | None = await thread_meta_repo.get(thread_id)
|
record: dict | None = await thread_store.get(thread_id)
|
||||||
|
|
||||||
# Derive accurate status from the checkpointer
|
# Derive accurate status from the checkpointer
|
||||||
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
config = {"configurable": {"thread_id": thread_id, "checkpoint_ns": ""}}
|
||||||
@@ -462,10 +462,10 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
ThreadMetaStore abstraction so that ``/threads/search`` reflects the
|
||||||
change immediately in both sqlite and memory backends.
|
change immediately in both sqlite and memory backends.
|
||||||
"""
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_store
|
||||||
|
|
||||||
checkpointer = get_checkpointer(request)
|
checkpointer = get_checkpointer(request)
|
||||||
thread_meta_repo = get_thread_meta_repo(request)
|
thread_store = get_thread_store(request)
|
||||||
|
|
||||||
# checkpoint_ns must be present in the config for aput — default to ""
|
# checkpoint_ns must be present in the config for aput — default to ""
|
||||||
# (the root graph namespace). checkpoint_id is optional; omitting it
|
# (the root graph namespace). checkpoint_id is optional; omitting it
|
||||||
@@ -529,7 +529,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
|
|||||||
new_title = body.values["title"]
|
new_title = body.values["title"]
|
||||||
if new_title: # Skip empty strings and None
|
if new_title: # Skip empty strings and None
|
||||||
try:
|
try:
|
||||||
await thread_meta_repo.update_display_name(thread_id, new_title)
|
await thread_store.update_display_name(thread_id, new_title)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
logger.debug("Failed to sync title to thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
|
|||||||
@@ -229,15 +229,15 @@ async def start_run(
|
|||||||
# even for threads that were never explicitly created via POST /threads
|
# even for threads that were never explicitly created via POST /threads
|
||||||
# (e.g. stateless runs).
|
# (e.g. stateless runs).
|
||||||
try:
|
try:
|
||||||
existing = await run_ctx.thread_meta_repo.get(thread_id)
|
existing = await run_ctx.thread_store.get(thread_id)
|
||||||
if existing is None:
|
if existing is None:
|
||||||
await run_ctx.thread_meta_repo.create(
|
await run_ctx.thread_store.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=body.assistant_id,
|
assistant_id=body.assistant_id,
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
await run_ctx.thread_meta_repo.update_status(thread_id, "running")
|
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
@@ -285,7 +285,7 @@ async def start_run(
|
|||||||
record.task = task
|
record.task = task
|
||||||
|
|
||||||
# Title sync is handled by worker.py's finally block which reads the
|
# Title sync is handled by worker.py's finally block which reads the
|
||||||
# title from the checkpoint and calls thread_meta_repo.update_display_name
|
# title from the checkpoint and calls thread_store.update_display_name
|
||||||
# after the run completes.
|
# after the run completes.
|
||||||
|
|
||||||
return record
|
return record
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class FeedbackRow(Base):
|
|||||||
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||||
message_id: Mapped[str | None] = mapped_column(String(64))
|
message_id: Mapped[str | None] = mapped_column(String(64))
|
||||||
# message_id is an optional RunEventStore event identifier —
|
# message_id is an optional RunEventStore event identifier —
|
||||||
# allows feedback to target a specific message or the entire run
|
# allows feedback to target a specific message or the entire run
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ from sqlalchemy import case, func, select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
|
|
||||||
from deerflow.persistence.feedback.model import FeedbackRow
|
from deerflow.persistence.feedback.model import FeedbackRow
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||||
|
|
||||||
|
|
||||||
class FeedbackRepository:
|
class FeedbackRepository:
|
||||||
@@ -33,19 +33,19 @@ class FeedbackRepository:
|
|||||||
run_id: str,
|
run_id: str,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
rating: int,
|
rating: int,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
message_id: str | None = None,
|
message_id: str | None = None,
|
||||||
comment: str | None = None,
|
comment: str | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Create a feedback record. rating must be +1 or -1."""
|
"""Create a feedback record. rating must be +1 or -1."""
|
||||||
if rating not in (1, -1):
|
if rating not in (1, -1):
|
||||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.create")
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.create")
|
||||||
row = FeedbackRow(
|
row = FeedbackRow(
|
||||||
feedback_id=str(uuid.uuid4()),
|
feedback_id=str(uuid.uuid4()),
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
owner_id=resolved_owner_id,
|
user_id=resolved_user_id,
|
||||||
message_id=message_id,
|
message_id=message_id,
|
||||||
rating=rating,
|
rating=rating,
|
||||||
comment=comment,
|
comment=comment,
|
||||||
@@ -61,14 +61,14 @@ class FeedbackRepository:
|
|||||||
self,
|
self,
|
||||||
feedback_id: str,
|
feedback_id: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> dict | None:
|
) -> dict | None:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.get")
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(FeedbackRow, feedback_id)
|
row = await session.get(FeedbackRow, feedback_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return None
|
return None
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
@@ -78,12 +78,12 @@ class FeedbackRepository:
|
|||||||
run_id: str,
|
run_id: str,
|
||||||
*,
|
*,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_run")
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_run")
|
||||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -94,12 +94,12 @@ class FeedbackRepository:
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.list_by_thread")
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread")
|
||||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(FeedbackRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||||
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
stmt = stmt.order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -109,14 +109,14 @@ class FeedbackRepository:
|
|||||||
self,
|
self,
|
||||||
feedback_id: str,
|
feedback_id: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="FeedbackRepository.delete")
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(FeedbackRow, feedback_id)
|
row = await session.get(FeedbackRow, feedback_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return False
|
return False
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return False
|
return False
|
||||||
await session.delete(row)
|
await session.delete(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ class RunEventRow(Base):
|
|||||||
# Owner of the conversation this event belongs to. Nullable for data
|
# Owner of the conversation this event belongs to. Nullable for data
|
||||||
# created before auth was introduced; populated by auth middleware on
|
# created before auth was introduced; populated by auth middleware on
|
||||||
# new writes and by the boot-time orphan migration on existing rows.
|
# new writes and by the boot-time orphan migration on existing rows.
|
||||||
owner_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
user_id: Mapped[str | None] = mapped_column(String(64), nullable=True, index=True)
|
||||||
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
event_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
||||||
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
category: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||||
# "message" | "trace" | "lifecycle"
|
# "message" | "trace" | "lifecycle"
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ class RunRow(Base):
|
|||||||
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||||
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
assistant_id: Mapped[str | None] = mapped_column(String(128))
|
||||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||||
status: Mapped[str] = mapped_column(String(20), default="pending")
|
status: Mapped[str] = mapped_column(String(20), default="pending")
|
||||||
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
# "pending" | "running" | "success" | "error" | "timeout" | "interrupted"
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.run.model import RunRow
|
from deerflow.persistence.run.model import RunRow
|
||||||
from deerflow.runtime.runs.store.base import RunStore
|
from deerflow.runtime.runs.store.base import RunStore
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||||
|
|
||||||
|
|
||||||
class RunRepository(RunStore):
|
class RunRepository(RunStore):
|
||||||
@@ -69,7 +69,7 @@ class RunRepository(RunStore):
|
|||||||
*,
|
*,
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=None,
|
assistant_id=None,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
status="pending",
|
status="pending",
|
||||||
multitask_strategy="reject",
|
multitask_strategy="reject",
|
||||||
metadata=None,
|
metadata=None,
|
||||||
@@ -78,13 +78,13 @@ class RunRepository(RunStore):
|
|||||||
created_at=None,
|
created_at=None,
|
||||||
follow_up_to_run_id=None,
|
follow_up_to_run_id=None,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.put")
|
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
row = RunRow(
|
row = RunRow(
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
assistant_id=assistant_id,
|
assistant_id=assistant_id,
|
||||||
owner_id=resolved_owner_id,
|
user_id=resolved_user_id,
|
||||||
status=status,
|
status=status,
|
||||||
multitask_strategy=multitask_strategy,
|
multitask_strategy=multitask_strategy,
|
||||||
metadata_json=self._safe_json(metadata) or {},
|
metadata_json=self._safe_json(metadata) or {},
|
||||||
@@ -102,14 +102,14 @@ class RunRepository(RunStore):
|
|||||||
self,
|
self,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.get")
|
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(RunRow, run_id)
|
row = await session.get(RunRow, run_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return None
|
return None
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
@@ -117,13 +117,13 @@ class RunRepository(RunStore):
|
|||||||
self,
|
self,
|
||||||
thread_id,
|
thread_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
limit=100,
|
limit=100,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.list_by_thread")
|
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.list_by_thread")
|
||||||
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
stmt = select(RunRow).where(RunRow.thread_id == thread_id)
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(RunRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(RunRow.user_id == resolved_user_id)
|
||||||
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
stmt = stmt.order_by(RunRow.created_at.desc()).limit(limit)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -141,14 +141,14 @@ class RunRepository(RunStore):
|
|||||||
self,
|
self,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="RunRepository.delete")
|
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(RunRow, run_id)
|
row = await session.get(RunRow, run_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return
|
return
|
||||||
await session.delete(row)
|
await session.delete(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ class ThreadMetaRow(Base):
|
|||||||
|
|
||||||
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
assistant_id: Mapped[str | None] = mapped_column(String(128), index=True)
|
||||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
user_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
from deerflow.persistence.thread_meta.base import ThreadMetaStore
|
||||||
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
from deerflow.persistence.thread_meta.model import ThreadMetaRow
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_owner_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, resolve_user_id
|
||||||
|
|
||||||
|
|
||||||
class ThreadMetaRepository(ThreadMetaStore):
|
class ThreadMetaRepository(ThreadMetaStore):
|
||||||
@@ -32,18 +32,18 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
display_name: str | None = None,
|
display_name: str | None = None,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
# Auto-resolve owner_id from contextvar when AUTO; explicit None
|
# Auto-resolve user_id from contextvar when AUTO; explicit None
|
||||||
# creates an orphan row (used by migration scripts).
|
# creates an orphan row (used by migration scripts).
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.create")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.create")
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
row = ThreadMetaRow(
|
row = ThreadMetaRow(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
assistant_id=assistant_id,
|
assistant_id=assistant_id,
|
||||||
owner_id=resolved_owner_id,
|
user_id=resolved_user_id,
|
||||||
display_name=display_name,
|
display_name=display_name,
|
||||||
metadata_json=metadata or {},
|
metadata_json=metadata or {},
|
||||||
created_at=now,
|
created_at=now,
|
||||||
@@ -59,40 +59,34 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> dict | None:
|
) -> dict | None:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.get")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.get")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return None
|
return None
|
||||||
# Enforce owner filter unless explicitly bypassed (owner_id=None).
|
# Enforce owner filter unless explicitly bypassed (user_id=None).
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return None
|
return None
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
async def list_by_owner(self, owner_id: str, *, limit: int = 100, offset: int = 0) -> list[dict]:
|
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||||
stmt = select(ThreadMetaRow).where(ThreadMetaRow.owner_id == owner_id).order_by(ThreadMetaRow.updated_at.desc()).limit(limit).offset(offset)
|
"""Check if ``user_id`` has access to ``thread_id``.
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
|
||||||
|
|
||||||
async def check_access(self, thread_id: str, owner_id: str, *, require_existing: bool = False) -> bool:
|
|
||||||
"""Check if ``owner_id`` has access to ``thread_id``.
|
|
||||||
|
|
||||||
Two modes — one row, two distinct semantics depending on what
|
Two modes — one row, two distinct semantics depending on what
|
||||||
the caller is about to do:
|
the caller is about to do:
|
||||||
|
|
||||||
- ``require_existing=False`` (default, permissive):
|
- ``require_existing=False`` (default, permissive):
|
||||||
Returns True for: row missing (untracked legacy thread),
|
Returns True for: row missing (untracked legacy thread),
|
||||||
``row.owner_id`` is None (shared / pre-auth data),
|
``row.user_id`` is None (shared / pre-auth data),
|
||||||
or ``row.owner_id == owner_id``. Use for **read-style**
|
or ``row.user_id == user_id``. Use for **read-style**
|
||||||
decorators where treating an untracked thread as accessible
|
decorators where treating an untracked thread as accessible
|
||||||
preserves backward-compat.
|
preserves backward-compat.
|
||||||
|
|
||||||
- ``require_existing=True`` (strict):
|
- ``require_existing=True`` (strict):
|
||||||
Returns True **only** when the row exists AND
|
Returns True **only** when the row exists AND
|
||||||
(``row.owner_id == owner_id`` OR ``row.owner_id is None``).
|
(``row.user_id == user_id`` OR ``row.user_id is None``).
|
||||||
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
Use for **destructive / mutating** decorators (DELETE, PATCH,
|
||||||
state-update) so a thread that has *already been deleted*
|
state-update) so a thread that has *already been deleted*
|
||||||
cannot be re-targeted by any caller — closing the
|
cannot be re-targeted by any caller — closing the
|
||||||
@@ -103,9 +97,9 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return not require_existing
|
return not require_existing
|
||||||
if row.owner_id is None:
|
if row.user_id is None:
|
||||||
return True
|
return True
|
||||||
return row.owner_id == owner_id
|
return row.user_id == user_id
|
||||||
|
|
||||||
async def search(
|
async def search(
|
||||||
self,
|
self,
|
||||||
@@ -114,17 +108,17 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
offset: int = 0,
|
offset: int = 0,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Search threads with optional metadata and status filters.
|
"""Search threads with optional metadata and status filters.
|
||||||
|
|
||||||
Owner filter is enforced by default: caller must be in a user
|
Owner filter is enforced by default: caller must be in a user
|
||||||
context. Pass ``owner_id=None`` to bypass (migration/CLI).
|
context. Pass ``user_id=None`` to bypass (migration/CLI).
|
||||||
"""
|
"""
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.search")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.search")
|
||||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(ThreadMetaRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(ThreadMetaRow.user_id == resolved_user_id)
|
||||||
if status:
|
if status:
|
||||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||||
|
|
||||||
@@ -144,24 +138,24 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|
||||||
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_owner_id: str | None) -> bool:
|
async def _check_ownership(self, session: AsyncSession, thread_id: str, resolved_user_id: str | None) -> bool:
|
||||||
"""Return True if the row exists and is owned (or filter bypassed)."""
|
"""Return True if the row exists and is owned (or filter bypassed)."""
|
||||||
if resolved_owner_id is None:
|
if resolved_user_id is None:
|
||||||
return True # explicit bypass
|
return True # explicit bypass
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
return row is not None and row.owner_id == resolved_owner_id
|
return row is not None and row.user_id == resolved_user_id
|
||||||
|
|
||||||
async def update_display_name(
|
async def update_display_name(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
display_name: str,
|
display_name: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update the display_name (title) for a thread."""
|
"""Update the display_name (title) for a thread."""
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_display_name")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_display_name")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||||
return
|
return
|
||||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -171,11 +165,11 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
status: str,
|
status: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_status")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_status")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
if not await self._check_ownership(session, thread_id, resolved_owner_id):
|
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||||
return
|
return
|
||||||
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(status=status, updated_at=datetime.now(UTC)))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -185,20 +179,20 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
thread_id: str,
|
thread_id: str,
|
||||||
metadata: dict,
|
metadata: dict,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Merge ``metadata`` into ``metadata_json``.
|
"""Merge ``metadata`` into ``metadata_json``.
|
||||||
|
|
||||||
Read-modify-write inside a single session/transaction so concurrent
|
Read-modify-write inside a single session/transaction so concurrent
|
||||||
callers see consistent state. No-op if the row does not exist or
|
callers see consistent state. No-op if the row does not exist or
|
||||||
the owner_id check fails.
|
the user_id check fails.
|
||||||
"""
|
"""
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.update_metadata")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_metadata")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return
|
return
|
||||||
merged = dict(row.metadata_json or {})
|
merged = dict(row.metadata_json or {})
|
||||||
merged.update(metadata)
|
merged.update(metadata)
|
||||||
@@ -210,14 +204,14 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
) -> None:
|
) -> None:
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="ThreadMetaRepository.delete")
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.delete")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
row = await session.get(ThreadMetaRow, thread_id)
|
row = await session.get(ThreadMetaRow, thread_id)
|
||||||
if row is None:
|
if row is None:
|
||||||
return
|
return
|
||||||
if resolved_owner_id is not None and row.owner_id != resolved_owner_id:
|
if resolved_user_id is not None and row.user_id != resolved_user_id:
|
||||||
return
|
return
|
||||||
await session.delete(row)
|
await session.delete(row)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|||||||
@@ -24,12 +24,12 @@ from collections.abc import AsyncIterator
|
|||||||
|
|
||||||
from langgraph.types import Checkpointer
|
from langgraph.types import Checkpointer
|
||||||
|
|
||||||
|
from deerflow.config.app_config import get_app_config
|
||||||
from deerflow.runtime.checkpointer.provider import (
|
from deerflow.runtime.checkpointer.provider import (
|
||||||
POSTGRES_CONN_REQUIRED,
|
POSTGRES_CONN_REQUIRED,
|
||||||
POSTGRES_INSTALL,
|
POSTGRES_INSTALL,
|
||||||
SQLITE_INSTALL,
|
SQLITE_INSTALL,
|
||||||
)
|
)
|
||||||
from deerflow.config.app_config import get_app_config
|
|
||||||
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
|||||||
|
|
||||||
from deerflow.persistence.models.run_event import RunEventRow
|
from deerflow.persistence.models.run_event import RunEventRow
|
||||||
from deerflow.runtime.events.store.base import RunEventStore
|
from deerflow.runtime.events.store.base import RunEventStore
|
||||||
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_owner_id
|
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -55,8 +55,8 @@ class DbRunEventStore(RunEventStore):
|
|||||||
return content, metadata or {}
|
return content, metadata or {}
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _owner_from_context() -> str | None:
|
def _user_id_from_context() -> str | None:
|
||||||
"""Soft read of owner_id from contextvar for write paths.
|
"""Soft read of user_id from contextvar for write paths.
|
||||||
|
|
||||||
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
Returns ``None`` (no filter / no stamp) if contextvar is unset,
|
||||||
which is the expected case for background worker writes. HTTP
|
which is the expected case for background worker writes. HTTP
|
||||||
@@ -81,7 +81,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||||
else:
|
else:
|
||||||
db_content = content
|
db_content = content
|
||||||
owner_id = self._owner_from_context()
|
user_id = self._user_id_from_context()
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||||
@@ -92,7 +92,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
owner_id=owner_id,
|
user_id=user_id,
|
||||||
event_type=event_type,
|
event_type=event_type,
|
||||||
category=category,
|
category=category,
|
||||||
content=db_content,
|
content=db_content,
|
||||||
@@ -106,7 +106,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
async def put_batch(self, events):
|
async def put_batch(self, events):
|
||||||
if not events:
|
if not events:
|
||||||
return []
|
return []
|
||||||
owner_id = self._owner_from_context()
|
user_id = self._user_id_from_context()
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||||
@@ -130,7 +130,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=e["thread_id"],
|
thread_id=e["thread_id"],
|
||||||
run_id=e["run_id"],
|
run_id=e["run_id"],
|
||||||
owner_id=e.get("owner_id", owner_id),
|
user_id=e.get("user_id", user_id),
|
||||||
event_type=e["event_type"],
|
event_type=e["event_type"],
|
||||||
category=category,
|
category=category,
|
||||||
content=db_content,
|
content=db_content,
|
||||||
@@ -149,12 +149,12 @@ class DbRunEventStore(RunEventStore):
|
|||||||
limit=50,
|
limit=50,
|
||||||
before_seq=None,
|
before_seq=None,
|
||||||
after_seq=None,
|
after_seq=None,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||||
if before_seq is not None:
|
if before_seq is not None:
|
||||||
stmt = stmt.where(RunEventRow.seq < before_seq)
|
stmt = stmt.where(RunEventRow.seq < before_seq)
|
||||||
if after_seq is not None:
|
if after_seq is not None:
|
||||||
@@ -181,12 +181,12 @@ class DbRunEventStore(RunEventStore):
|
|||||||
*,
|
*,
|
||||||
event_types=None,
|
event_types=None,
|
||||||
limit=500,
|
limit=500,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_events")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id)
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||||
if event_types:
|
if event_types:
|
||||||
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
stmt = stmt.where(RunEventRow.event_type.in_(event_types))
|
||||||
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
stmt = stmt.order_by(RunEventRow.seq.asc()).limit(limit)
|
||||||
@@ -199,12 +199,12 @@ class DbRunEventStore(RunEventStore):
|
|||||||
thread_id,
|
thread_id,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.list_messages_by_run")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
|
||||||
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
|
stmt = select(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id, RunEventRow.category == "message")
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||||
stmt = stmt.order_by(RunEventRow.seq.asc())
|
stmt = stmt.order_by(RunEventRow.seq.asc())
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
@@ -214,12 +214,12 @@ class DbRunEventStore(RunEventStore):
|
|||||||
self,
|
self,
|
||||||
thread_id,
|
thread_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.count_messages")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
|
||||||
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
stmt = select(func.count()).select_from(RunEventRow).where(RunEventRow.thread_id == thread_id, RunEventRow.category == "message")
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
stmt = stmt.where(RunEventRow.owner_id == resolved_owner_id)
|
stmt = stmt.where(RunEventRow.user_id == resolved_user_id)
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
return await session.scalar(stmt) or 0
|
return await session.scalar(stmt) or 0
|
||||||
|
|
||||||
@@ -227,13 +227,13 @@ class DbRunEventStore(RunEventStore):
|
|||||||
self,
|
self,
|
||||||
thread_id,
|
thread_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_thread")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
count_conditions = [RunEventRow.thread_id == thread_id]
|
count_conditions = [RunEventRow.thread_id == thread_id]
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||||
count = await session.scalar(count_stmt) or 0
|
count = await session.scalar(count_stmt) or 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
@@ -246,13 +246,13 @@ class DbRunEventStore(RunEventStore):
|
|||||||
thread_id,
|
thread_id,
|
||||||
run_id,
|
run_id,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None | _AutoSentinel = AUTO,
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
):
|
):
|
||||||
resolved_owner_id = resolve_owner_id(owner_id, method_name="DbRunEventStore.delete_by_run")
|
resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
count_conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
|
||||||
if resolved_owner_id is not None:
|
if resolved_user_id is not None:
|
||||||
count_conditions.append(RunEventRow.owner_id == resolved_owner_id)
|
count_conditions.append(RunEventRow.user_id == resolved_user_id)
|
||||||
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
count_stmt = select(func.count()).select_from(RunEventRow).where(*count_conditions)
|
||||||
count = await session.scalar(count_stmt) or 0
|
count = await session.scalar(count_stmt) or 0
|
||||||
if count > 0:
|
if count > 0:
|
||||||
|
|||||||
@@ -4,8 +4,8 @@ RunManager depends on this interface. Implementations:
|
|||||||
- MemoryRunStore: in-memory dict (development, tests)
|
- MemoryRunStore: in-memory dict (development, tests)
|
||||||
- Future: RunRepository backed by SQLAlchemy ORM
|
- Future: RunRepository backed by SQLAlchemy ORM
|
||||||
|
|
||||||
All methods accept an optional owner_id for user isolation.
|
All methods accept an optional user_id for user isolation.
|
||||||
When owner_id is None, no user filtering is applied (single-user mode).
|
When user_id is None, no user filtering is applied (single-user mode).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -22,7 +22,7 @@ class RunStore(abc.ABC):
|
|||||||
*,
|
*,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
assistant_id: str | None = None,
|
assistant_id: str | None = None,
|
||||||
owner_id: str | None = None,
|
user_id: str | None = None,
|
||||||
status: str = "pending",
|
status: str = "pending",
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
metadata: dict[str, Any] | None = None,
|
metadata: dict[str, Any] | None = None,
|
||||||
@@ -42,7 +42,7 @@ class RunStore(abc.ABC):
|
|||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
*,
|
*,
|
||||||
owner_id: str | None = None,
|
user_id: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ class MemoryRunStore(RunStore):
|
|||||||
*,
|
*,
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=None,
|
assistant_id=None,
|
||||||
owner_id=None,
|
user_id=None,
|
||||||
status="pending",
|
status="pending",
|
||||||
multitask_strategy="reject",
|
multitask_strategy="reject",
|
||||||
metadata=None,
|
metadata=None,
|
||||||
@@ -35,7 +35,7 @@ class MemoryRunStore(RunStore):
|
|||||||
"run_id": run_id,
|
"run_id": run_id,
|
||||||
"thread_id": thread_id,
|
"thread_id": thread_id,
|
||||||
"assistant_id": assistant_id,
|
"assistant_id": assistant_id,
|
||||||
"owner_id": owner_id,
|
"user_id": user_id,
|
||||||
"status": status,
|
"status": status,
|
||||||
"multitask_strategy": multitask_strategy,
|
"multitask_strategy": multitask_strategy,
|
||||||
"metadata": metadata or {},
|
"metadata": metadata or {},
|
||||||
@@ -49,8 +49,8 @@ class MemoryRunStore(RunStore):
|
|||||||
async def get(self, run_id):
|
async def get(self, run_id):
|
||||||
return self._runs.get(run_id)
|
return self._runs.get(run_id)
|
||||||
|
|
||||||
async def list_by_thread(self, thread_id, *, owner_id=None, limit=100):
|
async def list_by_thread(self, thread_id, *, user_id=None, limit=100):
|
||||||
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (owner_id is None or r.get("owner_id") == owner_id)]
|
results = [r for r in self._runs.values() if r["thread_id"] == thread_id and (user_id is None or r.get("user_id") == user_id)]
|
||||||
results.sort(key=lambda r: r["created_at"], reverse=True)
|
results.sort(key=lambda r: r["created_at"], reverse=True)
|
||||||
return results[:limit]
|
return results[:limit]
|
||||||
|
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ class RunContext:
|
|||||||
store: Any | None = field(default=None)
|
store: Any | None = field(default=None)
|
||||||
event_store: Any | None = field(default=None)
|
event_store: Any | None = field(default=None)
|
||||||
run_events_config: Any | None = field(default=None)
|
run_events_config: Any | None = field(default=None)
|
||||||
thread_meta_repo: Any | None = field(default=None)
|
thread_store: Any | None = field(default=None)
|
||||||
follow_up_to_run_id: str | None = field(default=None)
|
follow_up_to_run_id: str | None = field(default=None)
|
||||||
|
|
||||||
|
|
||||||
@@ -75,7 +75,7 @@ async def run_agent(
|
|||||||
store = ctx.store
|
store = ctx.store
|
||||||
event_store = ctx.event_store
|
event_store = ctx.event_store
|
||||||
run_events_config = ctx.run_events_config
|
run_events_config = ctx.run_events_config
|
||||||
thread_meta_repo = ctx.thread_meta_repo
|
thread_store = ctx.thread_store
|
||||||
follow_up_to_run_id = ctx.follow_up_to_run_id
|
follow_up_to_run_id = ctx.follow_up_to_run_id
|
||||||
|
|
||||||
run_id = record.run_id
|
run_id = record.run_id
|
||||||
@@ -376,14 +376,14 @@ async def run_agent(
|
|||||||
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
ckpt = getattr(ckpt_tuple, "checkpoint", {}) or {}
|
||||||
title = ckpt.get("channel_values", {}).get("title")
|
title = ckpt.get("channel_values", {}).get("title")
|
||||||
if title:
|
if title:
|
||||||
await thread_meta_repo.update_display_name(thread_id, title)
|
await thread_store.update_display_name(thread_id, title)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
|
||||||
|
|
||||||
# Update threads_meta status based on run outcome
|
# Update threads_meta status based on run outcome
|
||||||
try:
|
try:
|
||||||
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
final_status = "idle" if record.status == RunStatus.success else record.status.value
|
||||||
await thread_meta_repo.update_status(thread_id, final_status)
|
await thread_store.update_status(thread_id, final_status)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Request-scoped user context for owner-based authorization.
|
"""Request-scoped user context for user-based authorization.
|
||||||
|
|
||||||
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
This module holds a :class:`~contextvars.ContextVar` that the gateway's
|
||||||
auth middleware sets after a successful authentication. Repository
|
auth middleware sets after a successful authentication. Repository
|
||||||
methods read the contextvar via a sentinel default parameter, letting
|
methods read the contextvar via a sentinel default parameter, letting
|
||||||
routers stay free of ``owner_id`` boilerplate.
|
routers stay free of ``user_id`` boilerplate.
|
||||||
|
|
||||||
Three-state semantics for the repository ``owner_id`` parameter (the
|
Three-state semantics for the repository ``user_id`` parameter (the
|
||||||
consumer side of this module lives in ``deerflow.persistence.*``):
|
consumer side of this module lives in ``deerflow.persistence.*``):
|
||||||
|
|
||||||
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
- ``_AUTO`` (module-private sentinel, default): read from contextvar;
|
||||||
@@ -91,16 +91,16 @@ def require_current_user() -> CurrentUser:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Sentinel-based owner_id resolution
|
# Sentinel-based user_id resolution
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
# Repository methods accept an ``owner_id`` keyword-only argument that
|
# Repository methods accept a ``user_id`` keyword-only argument that
|
||||||
# defaults to ``AUTO``. The three possible values drive distinct
|
# defaults to ``AUTO``. The three possible values drive distinct
|
||||||
# behaviours; see the docstring on :func:`resolve_owner_id`.
|
# behaviours; see the docstring on :func:`resolve_user_id`.
|
||||||
|
|
||||||
|
|
||||||
class _AutoSentinel:
|
class _AutoSentinel:
|
||||||
"""Singleton marker meaning 'resolve owner_id from contextvar'."""
|
"""Singleton marker meaning 'resolve user_id from contextvar'."""
|
||||||
|
|
||||||
_instance: _AutoSentinel | None = None
|
_instance: _AutoSentinel | None = None
|
||||||
|
|
||||||
@@ -116,12 +116,12 @@ class _AutoSentinel:
|
|||||||
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
AUTO: Final[_AutoSentinel] = _AutoSentinel()
|
||||||
|
|
||||||
|
|
||||||
def resolve_owner_id(
|
def resolve_user_id(
|
||||||
value: str | None | _AutoSentinel,
|
value: str | None | _AutoSentinel,
|
||||||
*,
|
*,
|
||||||
method_name: str = "repository method",
|
method_name: str = "repository method",
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
"""Resolve the owner_id parameter passed to a repository method.
|
"""Resolve the user_id parameter passed to a repository method.
|
||||||
|
|
||||||
Three-state semantics:
|
Three-state semantics:
|
||||||
|
|
||||||
@@ -131,16 +131,16 @@ def resolve_owner_id(
|
|||||||
- Explicit ``str``: use the provided id verbatim, overriding any
|
- Explicit ``str``: use the provided id verbatim, overriding any
|
||||||
contextvar value. Useful for tests and admin-override flows.
|
contextvar value. Useful for tests and admin-override flows.
|
||||||
- Explicit ``None``: no filter — the repository should skip the
|
- Explicit ``None``: no filter — the repository should skip the
|
||||||
owner_id WHERE clause entirely. Reserved for migration scripts
|
user_id WHERE clause entirely. Reserved for migration scripts
|
||||||
and CLI tools that intentionally bypass isolation.
|
and CLI tools that intentionally bypass isolation.
|
||||||
"""
|
"""
|
||||||
if isinstance(value, _AutoSentinel):
|
if isinstance(value, _AutoSentinel):
|
||||||
user = _current_user.get()
|
user = _current_user.get()
|
||||||
if user is None:
|
if user is None:
|
||||||
raise RuntimeError(f"{method_name} called with owner_id=AUTO but no user context is set; pass an explicit owner_id, set the contextvar via auth middleware, or opt out with owner_id=None for migration/CLI paths.")
|
raise RuntimeError(f"{method_name} called with user_id=AUTO but no user context is set; pass an explicit user_id, set the contextvar via auth middleware, or opt out with user_id=None for migration/CLI paths.")
|
||||||
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
|
# Coerce to ``str`` at the boundary: ``User.id`` is typed as
|
||||||
# ``UUID`` for the API surface, but the persistence layer
|
# ``UUID`` for the API surface, but the persistence layer
|
||||||
# stores ``owner_id`` as ``String(64)`` and aiosqlite cannot
|
# stores ``user_id`` as ``String(64)`` and aiosqlite cannot
|
||||||
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
|
# bind a raw UUID object to a VARCHAR column ("type 'UUID' is
|
||||||
# not supported"). Honour the documented return type here
|
# not supported"). Honour the documented return type here
|
||||||
# rather than ripple a type change through every caller.
|
# rather than ripple a type change through every caller.
|
||||||
|
|||||||
@@ -3,16 +3,16 @@
|
|||||||
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
|
The production gateway runs ``AuthMiddleware`` (validates the JWT cookie)
|
||||||
ahead of every router, plus ``@require_permission(owner_check=True)``
|
ahead of every router, plus ``@require_permission(owner_check=True)``
|
||||||
decorators that read ``request.state.auth`` and call
|
decorators that read ``request.state.auth`` and call
|
||||||
``thread_meta_repo.check_access``. Router-level unit tests construct
|
``thread_store.check_access``. Router-level unit tests construct
|
||||||
**bare** FastAPI apps that include only one router — they have neither
|
**bare** FastAPI apps that include only one router — they have neither
|
||||||
the auth middleware nor a real thread_meta_repo, so the decorators raise
|
the auth middleware nor a real thread_store, so the decorators raise
|
||||||
401 (TestClient path) or ValueError (direct-call path).
|
401 (TestClient path) or ValueError (direct-call path).
|
||||||
|
|
||||||
This module provides two surfaces:
|
This module provides two surfaces:
|
||||||
|
|
||||||
1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
|
1. :func:`make_authed_test_app` — wraps ``FastAPI()`` with a tiny
|
||||||
``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
|
``BaseHTTPMiddleware`` that stamps a fake user / AuthContext on every
|
||||||
request, plus a permissive ``thread_meta_repo`` mock on
|
request, plus a permissive ``thread_store`` mock on
|
||||||
``app.state``. Use from TestClient-based router tests.
|
``app.state``. Use from TestClient-based router tests.
|
||||||
|
|
||||||
2. :func:`call_unwrapped` — invokes the underlying function bypassing
|
2. :func:`call_unwrapped` — invokes the underlying function bypassing
|
||||||
@@ -86,20 +86,20 @@ def make_authed_test_app(
|
|||||||
user_factory: Callable[[], User] | None = None,
|
user_factory: Callable[[], User] | None = None,
|
||||||
owner_check_passes: bool = True,
|
owner_check_passes: bool = True,
|
||||||
) -> FastAPI:
|
) -> FastAPI:
|
||||||
"""Build a FastAPI test app with stub auth + permissive thread_meta_repo.
|
"""Build a FastAPI test app with stub auth + permissive thread_store.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user_factory: Override the default test user. Must return a fully
|
user_factory: Override the default test user. Must return a fully
|
||||||
populated :class:`User`. Useful for cross-user isolation tests
|
populated :class:`User`. Useful for cross-user isolation tests
|
||||||
that need a stable id across requests.
|
that need a stable id across requests.
|
||||||
owner_check_passes: When True (default), ``thread_meta_repo.check_access``
|
owner_check_passes: When True (default), ``thread_store.check_access``
|
||||||
returns True for every call so ``@require_permission(owner_check=True)``
|
returns True for every call so ``@require_permission(owner_check=True)``
|
||||||
never blocks the route under test. Pass False to verify that
|
never blocks the route under test. Pass False to verify that
|
||||||
permission failures surface correctly.
|
permission failures surface correctly.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A ``FastAPI`` app with the stub middleware installed and
|
A ``FastAPI`` app with the stub middleware installed and
|
||||||
``app.state.thread_meta_repo`` set to a permissive mock. The
|
``app.state.thread_store`` set to a permissive mock. The
|
||||||
caller is still responsible for ``app.include_router(...)``.
|
caller is still responsible for ``app.include_router(...)``.
|
||||||
"""
|
"""
|
||||||
factory = user_factory or _make_stub_user
|
factory = user_factory or _make_stub_user
|
||||||
@@ -108,7 +108,7 @@ def make_authed_test_app(
|
|||||||
|
|
||||||
repo = MagicMock()
|
repo = MagicMock()
|
||||||
repo.check_access = AsyncMock(return_value=owner_check_passes)
|
repo.check_access = AsyncMock(return_value=owner_check_passes)
|
||||||
app.state.thread_meta_repo = repo
|
app.state.thread_store = repo
|
||||||
|
|
||||||
return app
|
return app
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ def provisioner_module():
|
|||||||
# Auto-set user context for every test unless marked no_auto_user
|
# Auto-set user context for every test unless marked no_auto_user
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
#
|
#
|
||||||
# Repository methods read ``owner_id`` from a contextvar by default
|
# Repository methods read ``user_id`` from a contextvar by default
|
||||||
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
# (see ``deerflow.runtime.user_context``). Without this fixture, every
|
||||||
# pre-existing persistence test would raise RuntimeError because the
|
# pre-existing persistence test would raise RuntimeError because the
|
||||||
# contextvar is unset. The fixture sets a default test user on every
|
# contextvar is unset. The fixture sets a default test user on every
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ from unittest.mock import AsyncMock, MagicMock, patch
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
import deerflow.config.app_config as app_config_module
|
import deerflow.config.app_config as app_config_module
|
||||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
|
||||||
from deerflow.config.checkpointer_config import (
|
from deerflow.config.checkpointer_config import (
|
||||||
CheckpointerConfig,
|
CheckpointerConfig,
|
||||||
get_checkpointer_config,
|
get_checkpointer_config,
|
||||||
load_checkpointer_config_from_dict,
|
load_checkpointer_config_from_dict,
|
||||||
set_checkpointer_config,
|
set_checkpointer_config,
|
||||||
)
|
)
|
||||||
|
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
|
|||||||
@@ -199,12 +199,12 @@ def test_migration_failure_is_non_fatal():
|
|||||||
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
# ── Section 5.1-5.6 upgrade path: orphan thread migration ────────────────
|
||||||
|
|
||||||
|
|
||||||
def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
def test_migrate_orphaned_threads_stamps_user_id_on_unowned_rows():
|
||||||
"""First boot finds Store-only legacy threads → stamps admin's id.
|
"""First boot finds Store-only legacy threads → stamps admin's id.
|
||||||
|
|
||||||
Validates the **TC-UPG-02 upgrade story**: an operator running main
|
Validates the **TC-UPG-02 upgrade story**: an operator running main
|
||||||
(no auth) accumulates threads in the LangGraph Store namespace
|
(no auth) accumulates threads in the LangGraph Store namespace
|
||||||
``("threads",)`` with no ``metadata.owner_id``. After upgrading to
|
``("threads",)`` with no ``metadata.user_id``. After upgrading to
|
||||||
feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
|
feat/auth-on-2.0-rc, the first ``_ensure_admin_user`` boot should
|
||||||
rewrite each unowned item with the freshly created admin's id.
|
rewrite each unowned item with the freshly created admin's id.
|
||||||
"""
|
"""
|
||||||
@@ -215,7 +215,7 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
|||||||
SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
|
SimpleNamespace(key="t1", value={"metadata": {"title": "old-thread-1"}}),
|
||||||
SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
|
SimpleNamespace(key="t2", value={"metadata": {"title": "old-thread-2"}}),
|
||||||
SimpleNamespace(key="t3", value={"metadata": {}}),
|
SimpleNamespace(key="t3", value={"metadata": {}}),
|
||||||
SimpleNamespace(key="t4", value={"metadata": {"owner_id": "someone-else", "title": "preserved"}}),
|
SimpleNamespace(key="t4", value={"metadata": {"user_id": "someone-else", "title": "preserved"}}),
|
||||||
]
|
]
|
||||||
store = AsyncMock()
|
store = AsyncMock()
|
||||||
# asearch returns the entire batch on first call, then an empty page
|
# asearch returns the entire batch on first call, then an empty page
|
||||||
@@ -235,11 +235,11 @@ def test_migrate_orphaned_threads_stamps_owner_id_on_unowned_rows():
|
|||||||
assert len(aput_calls) == 3
|
assert len(aput_calls) == 3
|
||||||
rewritten_keys = {call[1] for call in aput_calls}
|
rewritten_keys = {call[1] for call in aput_calls}
|
||||||
assert rewritten_keys == {"t1", "t2", "t3"}
|
assert rewritten_keys == {"t1", "t2", "t3"}
|
||||||
# Each rewrite carries the new owner_id; titles preserved where present.
|
# Each rewrite carries the new user_id; titles preserved where present.
|
||||||
by_key = {call[1]: call[2] for call in aput_calls}
|
by_key = {call[1]: call[2] for call in aput_calls}
|
||||||
assert by_key["t1"]["metadata"]["owner_id"] == "admin-id-42"
|
assert by_key["t1"]["metadata"]["user_id"] == "admin-id-42"
|
||||||
assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
|
assert by_key["t1"]["metadata"]["title"] == "old-thread-1"
|
||||||
assert by_key["t3"]["metadata"]["owner_id"] == "admin-id-42"
|
assert by_key["t3"]["metadata"]["user_id"] == "admin-id-42"
|
||||||
# The pre-owned item must NOT have been rewritten.
|
# The pre-owned item must NOT have been rewritten.
|
||||||
assert "t4" not in rewritten_keys
|
assert "t4" not in rewritten_keys
|
||||||
|
|
||||||
|
|||||||
@@ -60,8 +60,8 @@ class TestFeedbackRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_create_with_owner(self, tmp_path):
|
async def test_create_with_owner(self, tmp_path):
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
record = await repo.create(run_id="r1", thread_id="t1", rating=1, owner_id="user-1")
|
record = await repo.create(run_id="r1", thread_id="t1", rating=1, user_id="user-1")
|
||||||
assert record["owner_id"] == "user-1"
|
assert record["user_id"] == "user-1"
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
|||||||
@@ -175,46 +175,46 @@ def _make_ctx(user_id):
|
|||||||
def test_filter_injects_user_id():
|
def test_filter_injects_user_id():
|
||||||
value = {}
|
value = {}
|
||||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||||
assert value["metadata"]["owner_id"] == "user-a"
|
assert value["metadata"]["user_id"] == "user-a"
|
||||||
|
|
||||||
|
|
||||||
def test_filter_preserves_existing_metadata():
|
def test_filter_preserves_existing_metadata():
|
||||||
value = {"metadata": {"title": "hello"}}
|
value = {"metadata": {"title": "hello"}}
|
||||||
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
asyncio.run(add_owner_filter(_make_ctx("user-a"), value))
|
||||||
assert value["metadata"]["owner_id"] == "user-a"
|
assert value["metadata"]["user_id"] == "user-a"
|
||||||
assert value["metadata"]["title"] == "hello"
|
assert value["metadata"]["title"] == "hello"
|
||||||
|
|
||||||
|
|
||||||
def test_filter_returns_user_id_dict():
|
def test_filter_returns_user_id_dict():
|
||||||
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
result = asyncio.run(add_owner_filter(_make_ctx("user-x"), {}))
|
||||||
assert result == {"owner_id": "user-x"}
|
assert result == {"user_id": "user-x"}
|
||||||
|
|
||||||
|
|
||||||
def test_filter_read_write_consistency():
|
def test_filter_read_write_consistency():
|
||||||
value = {}
|
value = {}
|
||||||
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
filter_dict = asyncio.run(add_owner_filter(_make_ctx("user-1"), value))
|
||||||
assert value["metadata"]["owner_id"] == filter_dict["owner_id"]
|
assert value["metadata"]["user_id"] == filter_dict["user_id"]
|
||||||
|
|
||||||
|
|
||||||
def test_different_users_different_filters():
|
def test_different_users_different_filters():
|
||||||
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
f_a = asyncio.run(add_owner_filter(_make_ctx("a"), {}))
|
||||||
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
f_b = asyncio.run(add_owner_filter(_make_ctx("b"), {}))
|
||||||
assert f_a["owner_id"] != f_b["owner_id"]
|
assert f_a["user_id"] != f_b["user_id"]
|
||||||
|
|
||||||
|
|
||||||
def test_filter_overrides_conflicting_user_id():
|
def test_filter_overrides_conflicting_user_id():
|
||||||
"""If value already has a different user_id in metadata, it gets overwritten."""
|
"""If value already has a different user_id in metadata, it gets overwritten."""
|
||||||
value = {"metadata": {"owner_id": "attacker"}}
|
value = {"metadata": {"user_id": "attacker"}}
|
||||||
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
asyncio.run(add_owner_filter(_make_ctx("real-owner"), value))
|
||||||
assert value["metadata"]["owner_id"] == "real-owner"
|
assert value["metadata"]["user_id"] == "real-owner"
|
||||||
|
|
||||||
|
|
||||||
def test_filter_with_empty_metadata():
|
def test_filter_with_empty_metadata():
|
||||||
"""Explicit empty metadata dict is fine."""
|
"""Explicit empty metadata dict is fine."""
|
||||||
value = {"metadata": {}}
|
value = {"metadata": {}}
|
||||||
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
result = asyncio.run(add_owner_filter(_make_ctx("user-z"), value))
|
||||||
assert value["metadata"]["owner_id"] == "user-z"
|
assert value["metadata"]["user_id"] == "user-z"
|
||||||
assert result == {"owner_id": "user-z"}
|
assert result == {"user_id": "user-z"}
|
||||||
|
|
||||||
|
|
||||||
# ── Gateway parity ───────────────────────────────────────────────────────
|
# ── Gateway parity ───────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -9,8 +9,8 @@ These tests bypass the HTTP layer and exercise the storage-layer
|
|||||||
owner filter directly by switching the ``user_context`` contextvar
|
owner filter directly by switching the ``user_context`` contextvar
|
||||||
between two users. The safety property under test is:
|
between two users. The safety property under test is:
|
||||||
|
|
||||||
After a repository write with owner_id=A, a subsequent read with
|
After a repository write with user_id=A, a subsequent read with
|
||||||
owner_id=B must not return the row, and vice versa.
|
user_id=B must not return the row, and vice versa.
|
||||||
|
|
||||||
The HTTP layer is covered by test_auth_middleware.py, which proves
|
The HTTP layer is covered by test_auth_middleware.py, which proves
|
||||||
that a request cookie reaches the ``set_current_user`` call. Together
|
that a request cookie reaches the ``set_current_user`` call. Together
|
||||||
@@ -431,13 +431,13 @@ async def test_repository_without_context_raises(tmp_path):
|
|||||||
await cleanup()
|
await cleanup()
|
||||||
|
|
||||||
|
|
||||||
# ── Escape hatch: explicit owner_id=None bypasses filter (for migration) ──
|
# ── Escape hatch: explicit user_id=None bypasses filter (for migration) ──
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@pytest.mark.no_auto_user
|
@pytest.mark.no_auto_user
|
||||||
async def test_explicit_none_bypasses_filter(tmp_path):
|
async def test_explicit_none_bypasses_filter(tmp_path):
|
||||||
"""Migration scripts pass owner_id=None to see all rows regardless of owner."""
|
"""Migration scripts pass user_id=None to see all rows regardless of owner."""
|
||||||
from deerflow.persistence.engine import get_session_factory
|
from deerflow.persistence.engine import get_session_factory
|
||||||
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
from deerflow.persistence.thread_meta import ThreadMetaRepository
|
||||||
|
|
||||||
@@ -452,14 +452,14 @@ async def test_explicit_none_bypasses_filter(tmp_path):
|
|||||||
await repo.create("t-beta")
|
await repo.create("t-beta")
|
||||||
|
|
||||||
# Migration-style read: no contextvar, explicit None bypass.
|
# Migration-style read: no contextvar, explicit None bypass.
|
||||||
all_rows = await repo.search(owner_id=None)
|
all_rows = await repo.search(user_id=None)
|
||||||
thread_ids = {r["thread_id"] for r in all_rows}
|
thread_ids = {r["thread_id"] for r in all_rows}
|
||||||
assert thread_ids == {"t-alpha", "t-beta"}
|
assert thread_ids == {"t-alpha", "t-beta"}
|
||||||
|
|
||||||
# Explicit get with None does not apply the filter either.
|
# Explicit get with None does not apply the filter either.
|
||||||
row_a = await repo.get("t-alpha", owner_id=None)
|
row_a = await repo.get("t-alpha", user_id=None)
|
||||||
assert row_a is not None
|
assert row_a is not None
|
||||||
row_b = await repo.get("t-beta", owner_id=None)
|
row_b = await repo.get("t-beta", user_id=None)
|
||||||
assert row_b is not None
|
assert row_b is not None
|
||||||
finally:
|
finally:
|
||||||
await cleanup()
|
await cleanup()
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
Tests:
|
Tests:
|
||||||
1. DatabaseConfig property derivation (paths, URLs)
|
1. DatabaseConfig property derivation (paths, URLs)
|
||||||
2. MemoryRunStore CRUD + owner_id filtering
|
2. MemoryRunStore CRUD + user_id filtering
|
||||||
3. Base.to_dict() via inspect mixin
|
3. Base.to_dict() via inspect mixin
|
||||||
4. Engine init/close lifecycle (memory + SQLite)
|
4. Engine init/close lifecycle (memory + SQLite)
|
||||||
5. Postgres missing-dep error message
|
5. Postgres missing-dep error message
|
||||||
@@ -106,17 +106,17 @@ class TestMemoryRunStore:
|
|||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_list_by_thread_owner_filter(self, store):
|
async def test_list_by_thread_owner_filter(self, store):
|
||||||
await store.put("r1", thread_id="t1", owner_id="alice")
|
await store.put("r1", thread_id="t1", user_id="alice")
|
||||||
await store.put("r2", thread_id="t1", owner_id="bob")
|
await store.put("r2", thread_id="t1", user_id="bob")
|
||||||
rows = await store.list_by_thread("t1", owner_id="alice")
|
rows = await store.list_by_thread("t1", user_id="alice")
|
||||||
assert len(rows) == 1
|
assert len(rows) == 1
|
||||||
assert rows[0]["owner_id"] == "alice"
|
assert rows[0]["user_id"] == "alice"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_owner_none_returns_all(self, store):
|
async def test_owner_none_returns_all(self, store):
|
||||||
await store.put("r1", thread_id="t1", owner_id="alice")
|
await store.put("r1", thread_id="t1", user_id="alice")
|
||||||
await store.put("r2", thread_id="t1", owner_id="bob")
|
await store.put("r2", thread_id="t1", user_id="bob")
|
||||||
rows = await store.list_by_thread("t1", owner_id=None)
|
rows = await store.list_by_thread("t1", user_id=None)
|
||||||
assert len(rows) == 2
|
assert len(rows) == 2
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
|||||||
@@ -73,11 +73,11 @@ class TestRunRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_list_by_thread_owner_filter(self, tmp_path):
|
async def test_list_by_thread_owner_filter(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||||
rows = await repo.list_by_thread("t1", owner_id="alice")
|
rows = await repo.list_by_thread("t1", user_id="alice")
|
||||||
assert len(rows) == 1
|
assert len(rows) == 1
|
||||||
assert rows[0]["owner_id"] == "alice"
|
assert rows[0]["user_id"] == "alice"
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
@@ -189,8 +189,8 @@ class TestRunRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_owner_none_returns_all(self, tmp_path):
|
async def test_owner_none_returns_all(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
await repo.put("r1", thread_id="t1", user_id="alice")
|
||||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
await repo.put("r2", thread_id="t1", user_id="bob")
|
||||||
rows = await repo.list_by_thread("t1", owner_id=None)
|
rows = await repo.list_by_thread("t1", user_id=None)
|
||||||
assert len(rows) == 2
|
assert len(rows) == 2
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|||||||
@@ -47,7 +47,7 @@ def test_generate_suggestions_parses_and_limits(monkeypatch):
|
|||||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||||
|
|
||||||
# Bypass the require_permission decorator (which needs request +
|
# Bypass the require_permission decorator (which needs request +
|
||||||
# thread_meta_repo) — these tests cover the parsing logic.
|
# thread_store) — these tests cover the parsing logic.
|
||||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||||
|
|
||||||
assert result.suggestions == ["Q1", "Q2", "Q3"]
|
assert result.suggestions == ["Q1", "Q2", "Q3"]
|
||||||
@@ -67,7 +67,7 @@ def test_generate_suggestions_parses_list_block_content(monkeypatch):
|
|||||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||||
|
|
||||||
# Bypass the require_permission decorator (which needs request +
|
# Bypass the require_permission decorator (which needs request +
|
||||||
# thread_meta_repo) — these tests cover the parsing logic.
|
# thread_store) — these tests cover the parsing logic.
|
||||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||||
|
|
||||||
assert result.suggestions == ["Q1", "Q2"]
|
assert result.suggestions == ["Q1", "Q2"]
|
||||||
@@ -87,7 +87,7 @@ def test_generate_suggestions_parses_output_text_block_content(monkeypatch):
|
|||||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||||
|
|
||||||
# Bypass the require_permission decorator (which needs request +
|
# Bypass the require_permission decorator (which needs request +
|
||||||
# thread_meta_repo) — these tests cover the parsing logic.
|
# thread_store) — these tests cover the parsing logic.
|
||||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||||
|
|
||||||
assert result.suggestions == ["Q1", "Q2"]
|
assert result.suggestions == ["Q1", "Q2"]
|
||||||
@@ -104,7 +104,7 @@ def test_generate_suggestions_returns_empty_on_model_error(monkeypatch):
|
|||||||
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
monkeypatch.setattr(suggestions, "create_chat_model", lambda **kwargs: fake_model)
|
||||||
|
|
||||||
# Bypass the require_permission decorator (which needs request +
|
# Bypass the require_permission decorator (which needs request +
|
||||||
# thread_meta_repo) — these tests cover the parsing logic.
|
# thread_store) — these tests cover the parsing logic.
|
||||||
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
result = asyncio.run(suggestions.generate_suggestions.__wrapped__("t1", req, request=None))
|
||||||
|
|
||||||
assert result.suggestions == []
|
assert result.suggestions == []
|
||||||
|
|||||||
@@ -43,8 +43,8 @@ class TestThreadMetaRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_create_with_owner_and_display_name(self, tmp_path):
|
async def test_create_with_owner_and_display_name(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
record = await repo.create("t1", owner_id="user1", display_name="My Thread")
|
record = await repo.create("t1", user_id="user1", display_name="My Thread")
|
||||||
assert record["owner_id"] == "user1"
|
assert record["user_id"] == "user1"
|
||||||
assert record["display_name"] == "My Thread"
|
assert record["display_name"] == "My Thread"
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@@ -61,26 +61,6 @@ class TestThreadMetaRepository:
|
|||||||
assert await repo.get("nonexistent") is None
|
assert await repo.get("nonexistent") is None
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_by_owner(self, tmp_path):
|
|
||||||
repo = await _make_repo(tmp_path)
|
|
||||||
await repo.create("t1", owner_id="user1")
|
|
||||||
await repo.create("t2", owner_id="user1")
|
|
||||||
await repo.create("t3", owner_id="user2")
|
|
||||||
results = await repo.list_by_owner("user1")
|
|
||||||
assert len(results) == 2
|
|
||||||
assert all(r["owner_id"] == "user1" for r in results)
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_list_by_owner_with_limit_and_offset(self, tmp_path):
|
|
||||||
repo = await _make_repo(tmp_path)
|
|
||||||
for i in range(5):
|
|
||||||
await repo.create(f"t{i}", owner_id="user1")
|
|
||||||
results = await repo.list_by_owner("user1", limit=2, offset=1)
|
|
||||||
assert len(results) == 2
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_check_access_no_record_allows(self, tmp_path):
|
async def test_check_access_no_record_allows(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
@@ -90,23 +70,23 @@ class TestThreadMetaRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_check_access_owner_matches(self, tmp_path):
|
async def test_check_access_owner_matches(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
await repo.create("t1", owner_id="user1")
|
await repo.create("t1", user_id="user1")
|
||||||
assert await repo.check_access("t1", "user1") is True
|
assert await repo.check_access("t1", "user1") is True
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_check_access_owner_mismatch(self, tmp_path):
|
async def test_check_access_owner_mismatch(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
await repo.create("t1", owner_id="user1")
|
await repo.create("t1", user_id="user1")
|
||||||
assert await repo.check_access("t1", "user2") is False
|
assert await repo.check_access("t1", "user2") is False
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_check_access_no_owner_allows_all(self, tmp_path):
|
async def test_check_access_no_owner_allows_all(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
# Explicit owner_id=None to bypass the new AUTO default that
|
# Explicit user_id=None to bypass the new AUTO default that
|
||||||
# would otherwise pick up the test user from the autouse fixture.
|
# would otherwise pick up the test user from the autouse fixture.
|
||||||
await repo.create("t1", owner_id=None)
|
await repo.create("t1", user_id=None)
|
||||||
assert await repo.check_access("t1", "anyone") is True
|
assert await repo.check_access("t1", "anyone") is True
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@@ -125,27 +105,27 @@ class TestThreadMetaRepository:
|
|||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_check_access_strict_owner_match_allowed(self, tmp_path):
|
async def test_check_access_strict_owner_match_allowed(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
await repo.create("t1", owner_id="user1")
|
await repo.create("t1", user_id="user1")
|
||||||
assert await repo.check_access("t1", "user1", require_existing=True) is True
|
assert await repo.check_access("t1", "user1", require_existing=True) is True
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
|
async def test_check_access_strict_owner_mismatch_denied(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
await repo.create("t1", owner_id="user1")
|
await repo.create("t1", user_id="user1")
|
||||||
assert await repo.check_access("t1", "user2", require_existing=True) is False
|
assert await repo.check_access("t1", "user2", require_existing=True) is False
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
|
async def test_check_access_strict_null_owner_still_allowed(self, tmp_path):
|
||||||
"""Even in strict mode, a row with NULL owner_id stays shared.
|
"""Even in strict mode, a row with NULL user_id stays shared.
|
||||||
|
|
||||||
The strict flag tightens the *missing row* case, not the *shared
|
The strict flag tightens the *missing row* case, not the *shared
|
||||||
row* case — legacy pre-auth rows that survived a clean migration
|
row* case — legacy pre-auth rows that survived a clean migration
|
||||||
without an owner are still everyone's.
|
without an owner are still everyone's.
|
||||||
"""
|
"""
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
await repo.create("t1", owner_id=None)
|
await repo.create("t1", user_id=None)
|
||||||
assert await repo.check_access("t1", "anyone", require_existing=True) is True
|
assert await repo.check_access("t1", "anyone", require_existing=True) is True
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
|
|||||||
@@ -113,14 +113,8 @@ def test_delete_thread_data_returns_generic_500_error(tmp_path):
|
|||||||
# ── Server-reserved metadata key stripping ──────────────────────────────────
|
# ── Server-reserved metadata key stripping ──────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def test_strip_reserved_metadata_removes_owner_id():
|
|
||||||
"""Client-supplied owner_id is dropped to prevent reflection attacks."""
|
|
||||||
out = threads._strip_reserved_metadata({"owner_id": "victim-id", "title": "ok"})
|
|
||||||
assert out == {"title": "ok"}
|
|
||||||
|
|
||||||
|
|
||||||
def test_strip_reserved_metadata_removes_user_id():
|
def test_strip_reserved_metadata_removes_user_id():
|
||||||
"""user_id is also reserved (defense in depth for any future use)."""
|
"""Client-supplied user_id is dropped to prevent reflection attacks."""
|
||||||
out = threads._strip_reserved_metadata({"user_id": "victim-id", "title": "ok"})
|
out = threads._strip_reserved_metadata({"user_id": "victim-id", "title": "ok"})
|
||||||
assert out == {"title": "ok"}
|
assert out == {"title": "ok"}
|
||||||
|
|
||||||
@@ -136,6 +130,6 @@ def test_strip_reserved_metadata_empty_input():
|
|||||||
assert threads._strip_reserved_metadata({}) == {}
|
assert threads._strip_reserved_metadata({}) == {}
|
||||||
|
|
||||||
|
|
||||||
def test_strip_reserved_metadata_strips_both_simultaneously():
|
def test_strip_reserved_metadata_strips_all_reserved_keys():
|
||||||
out = threads._strip_reserved_metadata({"owner_id": "x", "user_id": "y", "keep": "me"})
|
out = threads._strip_reserved_metadata({"user_id": "x", "keep": "me"})
|
||||||
assert out == {"keep": "me"}
|
assert out == {"keep": "me"}
|
||||||
|
|||||||
Reference in New Issue
Block a user