diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index 7711af187..523c64326 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -440,6 +440,12 @@ def _human_input_message(content: str, *, original_content: str | None = None) - return message +def _owner_headers(msg: InboundMessage) -> dict[str, str] | None: + if not msg.owner_user_id: + return None + return create_internal_auth_headers(owner_user_id=msg.owner_user_id) + + def _resolve_slash_skill_command( text: str, available_skills: set[str] | None = None, @@ -914,7 +920,11 @@ class ChannelManager: async def _create_thread(self, client, msg: InboundMessage) -> str: """Create a new thread through Gateway and store the mapping.""" - thread = await client.threads.create() + owner_headers = _owner_headers(msg) + if owner_headers: + thread = await client.threads.create(headers=owner_headers) + else: + thread = await client.threads.create() thread_id = thread["thread_id"] await self._store_thread_id(msg, thread_id) logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id) @@ -969,14 +979,19 @@ class ChannelManager: return logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100]) + run_kwargs: dict[str, Any] = { + "input": {"messages": [human_message]}, + "config": run_config, + "context": run_context, + "multitask_strategy": "reject", + } + if owner_headers := _owner_headers(msg): + run_kwargs["headers"] = owner_headers try: result = await client.runs.wait( thread_id, assistant_id, - input={"messages": [human_message]}, - config=run_config, - context=run_context, - multitask_strategy="reject", + **run_kwargs, ) except Exception as exc: if _is_thread_busy_error(exc): @@ -1039,16 +1054,21 @@ class ChannelManager: last_published_text = "" last_publish_at = 0.0 stream_error: BaseException | None = None + stream_kwargs: dict[str, Any] = { + "input": {"messages": [human_message]}, + "config": run_config, + "context": run_context, + "stream_mode": ["messages-tuple", "values"], + "multitask_strategy": "reject", + } + if owner_headers := _owner_headers(msg): + stream_kwargs["headers"] = owner_headers try: async for chunk in client.runs.stream( thread_id, assistant_id, - input={"messages": [human_message]}, - config=run_config, - context=run_context, - stream_mode=["messages-tuple", "values"], - multitask_strategy="reject", + **stream_kwargs, ): event = getattr(chunk, "event", "") data = getattr(chunk, "data", None) diff --git a/backend/app/gateway/authz.py b/backend/app/gateway/authz.py index c7cf63858..beec645a3 100644 --- a/backend/app/gateway/authz.py +++ b/backend/app/gateway/authz.py @@ -276,6 +276,11 @@ def require_permission( # strict-deny rather than strict-allow — only an *existing* # row with a *different* user_id triggers 404. if owner_check: + from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE + + if getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE: + return await func(*args, **kwargs) + thread_id = kwargs.get("thread_id") if thread_id is None: raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter") diff --git a/backend/app/gateway/internal_auth.py b/backend/app/gateway/internal_auth.py index 3a00a9662..400e997bb 100644 --- a/backend/app/gateway/internal_auth.py +++ b/backend/app/gateway/internal_auth.py @@ -5,10 +5,12 @@ from __future__ import annotations import os import secrets from types import SimpleNamespace +from typing import Any from deerflow.runtime.user_context import DEFAULT_USER_ID INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token" +INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id" INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN" INTERNAL_SYSTEM_ROLE = "internal" @@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str: _INTERNAL_AUTH_TOKEN = _load_internal_auth_token() -def create_internal_auth_headers() -> dict[str, str]: +def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]: """Return headers that authenticate trusted Gateway internal calls.""" - return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN} + headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN} + if owner_user_id: + headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id + return headers def is_valid_internal_auth_token(token: str | None) -> bool: @@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool: def get_internal_user(): """Return the synthetic user used for trusted internal channel calls.""" return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE) + + +def get_trusted_internal_owner_user_id(request: Any) -> str | None: + """Return the owner override for a trusted internal request, if present. + + The header is ignored for normal browser/API callers. It is only honored + after ``AuthMiddleware`` has validated the internal auth token and stamped + the synthetic internal user onto ``request.state.user``. + """ + user = getattr(getattr(request, "state", None), "user", None) + if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE: + return None + + owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME) + if not owner_user_id: + return None + owner_user_id = owner_user_id.strip() + return owner_user_id or None diff --git a/backend/app/gateway/routers/threads.py b/backend/app/gateway/routers/threads.py index fa8de61ff..fd6c05289 100644 --- a/backend/app/gateway/routers/threads.py +++ b/backend/app/gateway/routers/threads.py @@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, field_validator from app.gateway.authz import require_permission from app.gateway.deps import get_checkpointer +from app.gateway.internal_auth import get_trusted_internal_owner_user_id from app.gateway.utils import sanitize_log_param from deerflow.config.paths import Paths, get_paths from deerflow.runtime import serialize_channel_values @@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe thread_store = get_thread_store(request) thread_id = body.thread_id or str(uuid.uuid4()) now = now_iso() + thread_owner_user_id = get_trusted_internal_owner_user_id(request) + thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {} # ``body.metadata`` is already stripped of server-reserved keys by # ``ThreadCreateRequest._strip_reserved`` — see the model definition. # Idempotency: return existing record when already present - existing_record = await thread_store.get(thread_id) + existing_record = await thread_store.get(thread_id, **thread_owner_kwargs) + if existing_record is None and thread_owner_user_id: + unscoped_record = await thread_store.get(thread_id, user_id=None) + if unscoped_record is not None: + if unscoped_record.get("user_id") != thread_owner_user_id: + await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None) + existing_record = await thread_store.get(thread_id, **thread_owner_kwargs) if existing_record is not None: return ThreadResponse( thread_id=thread_id, @@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe await thread_store.create( thread_id, assistant_id=getattr(body, "assistant_id", None), + **thread_owner_kwargs, metadata=body.metadata, ) except Exception: diff --git a/backend/app/gateway/services.py b/backend/app/gateway/services.py index 2c5c01e61..015f74398 100644 --- a/backend/app/gateway/services.py +++ b/backend/app/gateway/services.py @@ -12,6 +12,7 @@ import json import logging import re from collections.abc import Mapping +from types import SimpleNamespace from typing import Any from fastapi import HTTPException, Request @@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage from langchain_core.messages.utils import convert_to_messages from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge -from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE +from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id from app.gateway.utils import sanitize_log_param from deerflow.config.app_config import get_app_config from deerflow.runtime import ( @@ -35,6 +36,7 @@ from deerflow.runtime import ( run_agent, ) from deerflow.runtime.runs.naming import resolve_root_run_name +from deerflow.runtime.user_context import reset_current_user, set_current_user logger = logging.getLogger(__name__) @@ -315,72 +317,85 @@ async def start_run( detail=f"Model {model_name!r} is not in the configured model allowlist", ) + owner_user_id = get_trusted_internal_owner_user_id(request) + owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None try: - record = await run_mgr.create_or_reject( - thread_id, - body.assistant_id, - on_disconnect=disconnect, - metadata=body.metadata or {}, - kwargs={"input": body.input, "config": body.config}, - multitask_strategy=body.multitask_strategy, - model_name=model_name, - ) - except ConflictError as exc: - raise HTTPException(status_code=409, detail=str(exc)) from exc - except UnsupportedStrategyError as exc: - raise HTTPException(status_code=501, detail=str(exc)) from exc - - # Upsert thread metadata so the thread appears in /threads/search, - # even for threads that were never explicitly created via POST /threads - # (e.g. stateless runs). - try: - existing = await run_ctx.thread_store.get(thread_id) - if existing is None: - await run_ctx.thread_store.create( + try: + record = await run_mgr.create_or_reject( thread_id, - assistant_id=body.assistant_id, - metadata=body.metadata, + body.assistant_id, + on_disconnect=disconnect, + metadata=body.metadata or {}, + kwargs={"input": body.input, "config": body.config}, + multitask_strategy=body.multitask_strategy, + model_name=model_name, + user_id=owner_user_id, ) - else: - await run_ctx.thread_store.update_status(thread_id, "running") - except Exception: - logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) + except ConflictError as exc: + raise HTTPException(status_code=409, detail=str(exc)) from exc + except UnsupportedStrategyError as exc: + raise HTTPException(status_code=501, detail=str(exc)) from exc - agent_factory = resolve_agent_factory(body.assistant_id) - graph_input = normalize_input(body.input) - config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id) + # Upsert thread metadata so the thread appears in /threads/search, + # even for threads that were never explicitly created via POST /threads + # (e.g. stateless runs). + try: + existing = await run_ctx.thread_store.get(thread_id) + if existing is None and owner_user_id: + unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None) + if unscoped_existing is not None: + if unscoped_existing.get("user_id") != owner_user_id: + await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None) + existing = await run_ctx.thread_store.get(thread_id) + if existing is None: + await run_ctx.thread_store.create( + thread_id, + assistant_id=body.assistant_id, + metadata=body.metadata, + ) + else: + await run_ctx.thread_store.update_status(thread_id, "running") + except Exception: + logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) - # Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``. - # The ``context`` field is a custom extension for the langgraph-compat layer - # that carries agent configuration (model_name, thinking_enabled, etc.). - # Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored. - merge_run_context_overrides(config, getattr(body, "context", None)) - inject_authenticated_user_context(config, request) + agent_factory = resolve_agent_factory(body.assistant_id) + graph_input = normalize_input(body.input) + config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id) - stream_modes = normalize_stream_modes(body.stream_mode) + # Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``. + # The ``context`` field is a custom extension for the langgraph-compat layer + # that carries agent configuration (model_name, thinking_enabled, etc.). + # Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored. + merge_run_context_overrides(config, getattr(body, "context", None)) + inject_authenticated_user_context(config, request) - task = asyncio.create_task( - run_agent( - bridge, - run_mgr, - record, - ctx=run_ctx, - agent_factory=agent_factory, - graph_input=graph_input, - config=config, - stream_modes=stream_modes, - stream_subgraphs=body.stream_subgraphs, - interrupt_before=body.interrupt_before, - interrupt_after=body.interrupt_after, + stream_modes = normalize_stream_modes(body.stream_mode) + + task = asyncio.create_task( + run_agent( + bridge, + run_mgr, + record, + ctx=run_ctx, + agent_factory=agent_factory, + graph_input=graph_input, + config=config, + stream_modes=stream_modes, + stream_subgraphs=body.stream_subgraphs, + interrupt_before=body.interrupt_before, + interrupt_after=body.interrupt_after, + ) ) - ) - record.task = task + record.task = task - # Title sync is handled by worker.py's finally block which reads the - # title from the checkpoint and calls thread_store.update_display_name - # after the run completes. + # Title sync is handled by worker.py's finally block which reads the + # title from the checkpoint and calls thread_store.update_display_name + # after the run completes. - return record + return record + finally: + if owner_context_token is not None: + reset_current_user(owner_context_token) async def sse_consumer( diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/base.py b/backend/packages/harness/deerflow/persistence/thread_meta/base.py index ed55ade8e..4207b4daa 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/base.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/base.py @@ -71,6 +71,15 @@ class ThreadMetaStore(abc.ABC): """ pass + @abc.abstractmethod + async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + """Move a thread metadata row to a new owner. + + Intended for trusted internal repair/migration paths. No-op if the + row does not exist or the caller fails the owner check. + """ + pass + @abc.abstractmethod async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: """Check if ``user_id`` has access to ``thread_id``.""" diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py index 4f642a938..b17d994f8 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/memory.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/memory.py @@ -127,6 +127,14 @@ class MemoryThreadMetaStore(ThreadMetaStore): record["updated_at"] = now_iso() await self._store.aput(THREADS_NS, thread_id, record) + async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: + record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_owner") + if record is None: + return + record["user_id"] = owner_user_id + record["updated_at"] = now_iso() + await self._store.aput(THREADS_NS, thread_id, record) + async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete") if record is None: diff --git a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py index 930128087..a5e7f51c5 100644 --- a/backend/packages/harness/deerflow/persistence/thread_meta/sql.py +++ b/backend/packages/harness/deerflow/persistence/thread_meta/sql.py @@ -211,6 +211,21 @@ class ThreadMetaRepository(ThreadMetaStore): row.updated_at = datetime.now(UTC) await session.commit() + async def update_owner( + self, + thread_id: str, + owner_user_id: str, + *, + user_id: str | None | _AutoSentinel = AUTO, + ) -> None: + """Move a thread metadata row to ``owner_user_id``.""" + resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_owner") + async with self._sf() as session: + if not await self._check_ownership(session, thread_id, resolved_user_id): + return + await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(user_id=owner_user_id, updated_at=datetime.now(UTC))) + await session.commit() + async def delete( self, thread_id: str, diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index ef45852fb..9a9082fb7 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -83,6 +83,7 @@ class RunRecord: multitask_strategy: str = "reject" metadata: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict) + user_id: str | None = None created_at: str = "" updated_at: str = "" task: asyncio.Task | None = field(default=None, repr=False) @@ -124,7 +125,7 @@ class RunManager: @staticmethod def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]: - return { + payload = { "thread_id": record.thread_id, "assistant_id": record.assistant_id, "status": record.status.value, @@ -135,6 +136,9 @@ class RunManager: "created_at": record.created_at, "model_name": record.model_name, } + if record.user_id is not None: + payload["user_id"] = record.user_id + return payload async def _call_store_with_retry( self, @@ -241,6 +245,7 @@ class RunManager: kwargs=row.get("kwargs") or {}, created_at=row.get("created_at") or "", updated_at=row.get("updated_at") or "", + user_id=row.get("user_id"), error=row.get("error"), model_name=row.get("model_name"), store_only=True, @@ -320,6 +325,7 @@ class RunManager: metadata: dict | None = None, kwargs: dict | None = None, multitask_strategy: str = "reject", + user_id: str | None = None, ) -> RunRecord: """Create a new pending run and register it.""" run_id = str(uuid.uuid4()) @@ -333,6 +339,7 @@ class RunManager: multitask_strategy=multitask_strategy, metadata=metadata or {}, kwargs=kwargs or {}, + user_id=user_id, created_at=now, updated_at=now, ) @@ -504,6 +511,7 @@ class RunManager: kwargs: dict | None = None, multitask_strategy: str = "reject", model_name: str | None = None, + user_id: str | None = None, ) -> RunRecord: """Atomically check for inflight runs and create a new one. @@ -546,6 +554,7 @@ class RunManager: multitask_strategy=multitask_strategy, metadata=metadata or {}, kwargs=kwargs or {}, + user_id=user_id, created_at=now, updated_at=now, model_name=model_name, diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 04d3614da..8d00de577 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -2388,6 +2388,7 @@ class TestResolveRunParamsUserId: class TestChannelManagerConnectionRouting: def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path): from app.channels.manager import ChannelManager + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME from deerflow.persistence.engine import close_engine async def go(): @@ -2453,6 +2454,16 @@ class TestChannelManagerConnectionRouting: assert second_context["user_id"] == "bob" assert second_context["channel_user_id"] == "U-bob" + first_create_headers = mock_client.threads.create.call_args_list[0].kwargs["headers"] + second_create_headers = mock_client.threads.create.call_args_list[1].kwargs["headers"] + assert first_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice" + assert second_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob" + + first_run_headers = mock_client.runs.wait.call_args_list[0].kwargs["headers"] + second_run_headers = mock_client.runs.wait.call_args_list[1].kwargs["headers"] + assert first_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice" + assert second_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob" + try: _run(go()) finally: diff --git a/backend/tests/test_gateway_services.py b/backend/tests/test_gateway_services.py index d62ed9371..78d05e916 100644 --- a/backend/tests/test_gateway_services.py +++ b/backend/tests/test_gateway_services.py @@ -474,6 +474,83 @@ def test_inject_authenticated_user_context_skips_internal_role(): assert config["context"]["user_id"] == "channel-user-7" +def test_start_run_uses_internal_owner_header_for_persistence(): + import asyncio + from types import SimpleNamespace + from unittest.mock import patch + + from langgraph.checkpoint.memory import InMemorySaver + from langgraph.store.memory import InMemoryStore + + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE + from app.gateway.services import start_run + from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore + from deerflow.runtime import RunManager + from deerflow.runtime.runs.store.memory import MemoryRunStore + from deerflow.runtime.user_context import get_effective_user_id + + async def _scenario(): + run_store = MemoryRunStore() + thread_store = MemoryThreadMetaStore(InMemoryStore()) + await thread_store.create("channel-thread", user_id="default", metadata={"legacy": True}) + run_manager = RunManager(store=run_store) + state = SimpleNamespace( + stream_bridge=SimpleNamespace(), + run_manager=run_manager, + checkpointer=InMemorySaver(), + store=InMemoryStore(), + run_event_store=SimpleNamespace(), + run_events_config=None, + thread_store=thread_store, + ) + request = SimpleNamespace( + headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"}, + state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)), + app=SimpleNamespace(state=state), + ) + body = SimpleNamespace( + assistant_id="lead_agent", + input={"messages": [{"role": "human", "content": "hi"}]}, + metadata={}, + config=None, + context=None, + on_disconnect="cancel", + multitask_strategy="reject", + stream_mode=None, + stream_subgraphs=False, + interrupt_before=None, + interrupt_after=None, + ) + task_context: dict[str, str] = {} + + async def fake_run_agent(*args, **kwargs): + task_context["user_id"] = get_effective_user_id() + + with ( + patch("app.gateway.services.resolve_agent_factory", return_value=object()), + patch("app.gateway.services.run_agent", side_effect=fake_run_agent), + ): + record = await start_run(body, "channel-thread", request) + await record.task + + owner_run = await run_store.get(record.run_id, user_id="owner-1") + default_run = await run_store.get(record.run_id, user_id="default") + owner_thread = await thread_store.get("channel-thread", user_id="owner-1") + default_thread = await thread_store.get("channel-thread", user_id="default") + return owner_run, default_run, owner_thread, default_thread, task_context + + owner_run, default_run, owner_thread, default_thread, task_context = asyncio.run(_scenario()) + + assert owner_run is not None + assert owner_run["user_id"] == "owner-1" + assert default_run is None + assert owner_thread is not None + assert owner_thread["user_id"] == "owner-1" + assert owner_thread["metadata"] == {"legacy": True} + assert default_thread is None + assert task_context["user_id"] == "owner-1" + + # --------------------------------------------------------------------------- # build_run_config — context / configurable precedence (LangGraph >= 0.6.0) # --------------------------------------------------------------------------- diff --git a/backend/tests/test_internal_auth.py b/backend/tests/test_internal_auth.py index 7e56e1dd0..478b00d83 100644 --- a/backend/tests/test_internal_auth.py +++ b/backend/tests/test_internal_auth.py @@ -33,3 +33,18 @@ def test_internal_auth_generates_process_local_fallback(monkeypatch): assert reloaded.is_valid_internal_auth_token(token) is True finally: importlib.reload(reloaded) + + +def test_internal_auth_headers_can_carry_owner_user_id(monkeypatch): + import app.gateway.internal_auth as internal_auth + + monkeypatch.setenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", "shared-token") + reloaded = importlib.reload(internal_auth) + try: + headers = reloaded.create_internal_auth_headers(owner_user_id="owner-1") + + assert headers[reloaded.INTERNAL_AUTH_HEADER_NAME] == "shared-token" + assert headers[reloaded.INTERNAL_OWNER_USER_ID_HEADER_NAME] == "owner-1" + finally: + monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False) + importlib.reload(reloaded) diff --git a/backend/tests/test_thread_meta_repo.py b/backend/tests/test_thread_meta_repo.py index 1cef3752b..c6fff8868 100644 --- a/backend/tests/test_thread_meta_repo.py +++ b/backend/tests/test_thread_meta_repo.py @@ -137,6 +137,19 @@ class TestThreadMetaRepository: async def test_update_metadata_nonexistent_is_noop(self, repo): await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise + @pytest.mark.anyio + async def test_update_owner_with_bypass_moves_row(self, repo): + await repo.create("t1", user_id="default", metadata={"source": "channel"}) + await repo.update_owner("t1", "owner-1", user_id=None) + + owner_row = await repo.get("t1", user_id="owner-1") + default_row = await repo.get("t1", user_id="default") + + assert owner_row is not None + assert owner_row["user_id"] == "owner-1" + assert owner_row["metadata"] == {"source": "channel"} + assert default_row is None + # --- search with metadata filter (SQL push-down) --- @pytest.mark.anyio diff --git a/backend/tests/test_threads_router.py b/backend/tests/test_threads_router.py index f6f6adcef..74e4c7a50 100644 --- a/backend/tests/test_threads_router.py +++ b/backend/tests/test_threads_router.py @@ -1,4 +1,5 @@ import re +from types import SimpleNamespace from unittest.mock import patch import pytest @@ -218,6 +219,37 @@ def test_create_thread_returns_iso_timestamps() -> None: assert body["created_at"] == body["updated_at"] +def test_internal_owner_header_assigns_thread_to_owner() -> None: + import asyncio + + from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE + + store = InMemoryStore() + checkpointer = InMemorySaver() + thread_store = MemoryThreadMetaStore(store) + request = SimpleNamespace( + headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"}, + state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)), + app=SimpleNamespace(state=SimpleNamespace(checkpointer=checkpointer, thread_store=thread_store)), + ) + + async def _scenario(): + response = await threads.create_thread( + threads.ThreadCreateRequest(thread_id="channel-thread", metadata={}), + request, + ) + owner_row = await thread_store.get("channel-thread", user_id="owner-1") + internal_row = await thread_store.get("channel-thread", user_id="default") + return response, owner_row, internal_row + + response, owner_row, internal_row = asyncio.run(_scenario()) + + assert response.thread_id == "channel-thread" + assert owner_row is not None + assert owner_row["user_id"] == "owner-1" + assert internal_row is None + + def test_get_thread_returns_iso_for_legacy_unix_record() -> None: """A thread record written by older versions stores ``time.time()`` floats. ``get_thread`` must transparently surface them as ISO so the