Make channel threads visible to connection owners

This commit is contained in:
taohe
2026-06-11 15:40:49 +08:00
parent 92f562920d
commit 09872af36c
14 changed files with 333 additions and 71 deletions
+30 -10
View File
@@ -440,6 +440,12 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
return message return message
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
if not msg.owner_user_id:
return None
return create_internal_auth_headers(owner_user_id=msg.owner_user_id)
def _resolve_slash_skill_command( def _resolve_slash_skill_command(
text: str, text: str,
available_skills: set[str] | None = None, available_skills: set[str] | None = None,
@@ -914,7 +920,11 @@ class ChannelManager:
async def _create_thread(self, client, msg: InboundMessage) -> str: async def _create_thread(self, client, msg: InboundMessage) -> str:
"""Create a new thread through Gateway and store the mapping.""" """Create a new thread through Gateway and store the mapping."""
thread = await client.threads.create() owner_headers = _owner_headers(msg)
if owner_headers:
thread = await client.threads.create(headers=owner_headers)
else:
thread = await client.threads.create()
thread_id = thread["thread_id"] thread_id = thread["thread_id"]
await self._store_thread_id(msg, thread_id) await self._store_thread_id(msg, thread_id)
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id) logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
@@ -969,14 +979,19 @@ class ChannelManager:
return return
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100]) logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
run_kwargs: dict[str, Any] = {
"input": {"messages": [human_message]},
"config": run_config,
"context": run_context,
"multitask_strategy": "reject",
}
if owner_headers := _owner_headers(msg):
run_kwargs["headers"] = owner_headers
try: try:
result = await client.runs.wait( result = await client.runs.wait(
thread_id, thread_id,
assistant_id, assistant_id,
input={"messages": [human_message]}, **run_kwargs,
config=run_config,
context=run_context,
multitask_strategy="reject",
) )
except Exception as exc: except Exception as exc:
if _is_thread_busy_error(exc): if _is_thread_busy_error(exc):
@@ -1039,16 +1054,21 @@ class ChannelManager:
last_published_text = "" last_published_text = ""
last_publish_at = 0.0 last_publish_at = 0.0
stream_error: BaseException | None = None stream_error: BaseException | None = None
stream_kwargs: dict[str, Any] = {
"input": {"messages": [human_message]},
"config": run_config,
"context": run_context,
"stream_mode": ["messages-tuple", "values"],
"multitask_strategy": "reject",
}
if owner_headers := _owner_headers(msg):
stream_kwargs["headers"] = owner_headers
try: try:
async for chunk in client.runs.stream( async for chunk in client.runs.stream(
thread_id, thread_id,
assistant_id, assistant_id,
input={"messages": [human_message]}, **stream_kwargs,
config=run_config,
context=run_context,
stream_mode=["messages-tuple", "values"],
multitask_strategy="reject",
): ):
event = getattr(chunk, "event", "") event = getattr(chunk, "event", "")
data = getattr(chunk, "data", None) data = getattr(chunk, "data", None)
+5
View File
@@ -276,6 +276,11 @@ def require_permission(
# strict-deny rather than strict-allow — only an *existing* # strict-deny rather than strict-allow — only an *existing*
# row with a *different* user_id triggers 404. # row with a *different* user_id triggers 404.
if owner_check: if owner_check:
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
if getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
return await func(*args, **kwargs)
thread_id = kwargs.get("thread_id") thread_id = kwargs.get("thread_id")
if thread_id is None: if thread_id is None:
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter") raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
+25 -2
View File
@@ -5,10 +5,12 @@ from __future__ import annotations
import os import os
import secrets import secrets
from types import SimpleNamespace from types import SimpleNamespace
from typing import Any
from deerflow.runtime.user_context import DEFAULT_USER_ID from deerflow.runtime.user_context import DEFAULT_USER_ID
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token" INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN" INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
INTERNAL_SYSTEM_ROLE = "internal" INTERNAL_SYSTEM_ROLE = "internal"
@@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str:
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token() _INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
def create_internal_auth_headers() -> dict[str, str]: def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
"""Return headers that authenticate trusted Gateway internal calls.""" """Return headers that authenticate trusted Gateway internal calls."""
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN} headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
if owner_user_id:
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
return headers
def is_valid_internal_auth_token(token: str | None) -> bool: def is_valid_internal_auth_token(token: str | None) -> bool:
@@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
def get_internal_user(): def get_internal_user():
"""Return the synthetic user used for trusted internal channel calls.""" """Return the synthetic user used for trusted internal channel calls."""
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE) return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
def get_trusted_internal_owner_user_id(request: Any) -> str | None:
"""Return the owner override for a trusted internal request, if present.
The header is ignored for normal browser/API callers. It is only honored
after ``AuthMiddleware`` has validated the internal auth token and stamped
the synthetic internal user onto ``request.state.user``.
"""
user = getattr(getattr(request, "state", None), "user", None)
if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
return None
owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
if not owner_user_id:
return None
owner_user_id = owner_user_id.strip()
return owner_user_id or None
+11 -1
View File
@@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, field_validator
from app.gateway.authz import require_permission from app.gateway.authz import require_permission
from app.gateway.deps import get_checkpointer from app.gateway.deps import get_checkpointer
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
from app.gateway.utils import sanitize_log_param from app.gateway.utils import sanitize_log_param
from deerflow.config.paths import Paths, get_paths from deerflow.config.paths import Paths, get_paths
from deerflow.runtime import serialize_channel_values from deerflow.runtime import serialize_channel_values
@@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
thread_store = get_thread_store(request) thread_store = get_thread_store(request)
thread_id = body.thread_id or str(uuid.uuid4()) thread_id = body.thread_id or str(uuid.uuid4())
now = now_iso() now = now_iso()
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
# ``body.metadata`` is already stripped of server-reserved keys by # ``body.metadata`` is already stripped of server-reserved keys by
# ``ThreadCreateRequest._strip_reserved`` — see the model definition. # ``ThreadCreateRequest._strip_reserved`` — see the model definition.
# Idempotency: return existing record when already present # Idempotency: return existing record when already present
existing_record = await thread_store.get(thread_id) existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
if existing_record is None and thread_owner_user_id:
unscoped_record = await thread_store.get(thread_id, user_id=None)
if unscoped_record is not None:
if unscoped_record.get("user_id") != thread_owner_user_id:
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
if existing_record is not None: if existing_record is not None:
return ThreadResponse( return ThreadResponse(
thread_id=thread_id, thread_id=thread_id,
@@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
await thread_store.create( await thread_store.create(
thread_id, thread_id,
assistant_id=getattr(body, "assistant_id", None), assistant_id=getattr(body, "assistant_id", None),
**thread_owner_kwargs,
metadata=body.metadata, metadata=body.metadata,
) )
except Exception: except Exception:
+72 -57
View File
@@ -12,6 +12,7 @@ import json
import logging import logging
import re import re
from collections.abc import Mapping from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any from typing import Any
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
@@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage
from langchain_core.messages.utils import convert_to_messages from langchain_core.messages.utils import convert_to_messages
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
from app.gateway.utils import sanitize_log_param from app.gateway.utils import sanitize_log_param
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import get_app_config
from deerflow.runtime import ( from deerflow.runtime import (
@@ -35,6 +36,7 @@ from deerflow.runtime import (
run_agent, run_agent,
) )
from deerflow.runtime.runs.naming import resolve_root_run_name from deerflow.runtime.runs.naming import resolve_root_run_name
from deerflow.runtime.user_context import reset_current_user, set_current_user
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -315,72 +317,85 @@ async def start_run(
detail=f"Model {model_name!r} is not in the configured model allowlist", detail=f"Model {model_name!r} is not in the configured model allowlist",
) )
owner_user_id = get_trusted_internal_owner_user_id(request)
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
try: try:
record = await run_mgr.create_or_reject( try:
thread_id, record = await run_mgr.create_or_reject(
body.assistant_id,
on_disconnect=disconnect,
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
model_name=model_name,
)
except ConflictError as exc:
raise HTTPException(status_code=409, detail=str(exc)) from exc
except UnsupportedStrategyError as exc:
raise HTTPException(status_code=501, detail=str(exc)) from exc
# Upsert thread metadata so the thread appears in /threads/search,
# even for threads that were never explicitly created via POST /threads
# (e.g. stateless runs).
try:
existing = await run_ctx.thread_store.get(thread_id)
if existing is None:
await run_ctx.thread_store.create(
thread_id, thread_id,
assistant_id=body.assistant_id, body.assistant_id,
metadata=body.metadata, on_disconnect=disconnect,
metadata=body.metadata or {},
kwargs={"input": body.input, "config": body.config},
multitask_strategy=body.multitask_strategy,
model_name=model_name,
user_id=owner_user_id,
) )
else: except ConflictError as exc:
await run_ctx.thread_store.update_status(thread_id, "running") raise HTTPException(status_code=409, detail=str(exc)) from exc
except Exception: except UnsupportedStrategyError as exc:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id)) raise HTTPException(status_code=501, detail=str(exc)) from exc
agent_factory = resolve_agent_factory(body.assistant_id) # Upsert thread metadata so the thread appears in /threads/search,
graph_input = normalize_input(body.input) # even for threads that were never explicitly created via POST /threads
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id) # (e.g. stateless runs).
try:
existing = await run_ctx.thread_store.get(thread_id)
if existing is None and owner_user_id:
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
if unscoped_existing is not None:
if unscoped_existing.get("user_id") != owner_user_id:
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
existing = await run_ctx.thread_store.get(thread_id)
if existing is None:
await run_ctx.thread_store.create(
thread_id,
assistant_id=body.assistant_id,
metadata=body.metadata,
)
else:
await run_ctx.thread_store.update_status(thread_id, "running")
except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``. agent_factory = resolve_agent_factory(body.assistant_id)
# The ``context`` field is a custom extension for the langgraph-compat layer graph_input = normalize_input(body.input)
# that carries agent configuration (model_name, thinking_enabled, etc.). config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
merge_run_context_overrides(config, getattr(body, "context", None))
inject_authenticated_user_context(config, request)
stream_modes = normalize_stream_modes(body.stream_mode) # Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
# The ``context`` field is a custom extension for the langgraph-compat layer
# that carries agent configuration (model_name, thinking_enabled, etc.).
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
merge_run_context_overrides(config, getattr(body, "context", None))
inject_authenticated_user_context(config, request)
task = asyncio.create_task( stream_modes = normalize_stream_modes(body.stream_mode)
run_agent(
bridge, task = asyncio.create_task(
run_mgr, run_agent(
record, bridge,
ctx=run_ctx, run_mgr,
agent_factory=agent_factory, record,
graph_input=graph_input, ctx=run_ctx,
config=config, agent_factory=agent_factory,
stream_modes=stream_modes, graph_input=graph_input,
stream_subgraphs=body.stream_subgraphs, config=config,
interrupt_before=body.interrupt_before, stream_modes=stream_modes,
interrupt_after=body.interrupt_after, stream_subgraphs=body.stream_subgraphs,
interrupt_before=body.interrupt_before,
interrupt_after=body.interrupt_after,
)
) )
) record.task = task
record.task = task
# Title sync is handled by worker.py's finally block which reads the # Title sync is handled by worker.py's finally block which reads the
# title from the checkpoint and calls thread_store.update_display_name # title from the checkpoint and calls thread_store.update_display_name
# after the run completes. # after the run completes.
return record return record
finally:
if owner_context_token is not None:
reset_current_user(owner_context_token)
async def sse_consumer( async def sse_consumer(
@@ -71,6 +71,15 @@ class ThreadMetaStore(abc.ABC):
""" """
pass pass
@abc.abstractmethod
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
"""Move a thread metadata row to a new owner.
Intended for trusted internal repair/migration paths. No-op if the
row does not exist or the caller fails the owner check.
"""
pass
@abc.abstractmethod @abc.abstractmethod
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool: async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
"""Check if ``user_id`` has access to ``thread_id``.""" """Check if ``user_id`` has access to ``thread_id``."""
@@ -127,6 +127,14 @@ class MemoryThreadMetaStore(ThreadMetaStore):
record["updated_at"] = now_iso() record["updated_at"] = now_iso()
await self._store.aput(THREADS_NS, thread_id, record) await self._store.aput(THREADS_NS, thread_id, record)
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_owner")
if record is None:
return
record["user_id"] = owner_user_id
record["updated_at"] = now_iso()
await self._store.aput(THREADS_NS, thread_id, record)
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None: async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete") record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
if record is None: if record is None:
@@ -211,6 +211,21 @@ class ThreadMetaRepository(ThreadMetaStore):
row.updated_at = datetime.now(UTC) row.updated_at = datetime.now(UTC)
await session.commit() await session.commit()
async def update_owner(
self,
thread_id: str,
owner_user_id: str,
*,
user_id: str | None | _AutoSentinel = AUTO,
) -> None:
"""Move a thread metadata row to ``owner_user_id``."""
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_owner")
async with self._sf() as session:
if not await self._check_ownership(session, thread_id, resolved_user_id):
return
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(user_id=owner_user_id, updated_at=datetime.now(UTC)))
await session.commit()
async def delete( async def delete(
self, self,
thread_id: str, thread_id: str,
@@ -83,6 +83,7 @@ class RunRecord:
multitask_strategy: str = "reject" multitask_strategy: str = "reject"
metadata: dict = field(default_factory=dict) metadata: dict = field(default_factory=dict)
kwargs: dict = field(default_factory=dict) kwargs: dict = field(default_factory=dict)
user_id: str | None = None
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
task: asyncio.Task | None = field(default=None, repr=False) task: asyncio.Task | None = field(default=None, repr=False)
@@ -124,7 +125,7 @@ class RunManager:
@staticmethod @staticmethod
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]: def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
return { payload = {
"thread_id": record.thread_id, "thread_id": record.thread_id,
"assistant_id": record.assistant_id, "assistant_id": record.assistant_id,
"status": record.status.value, "status": record.status.value,
@@ -135,6 +136,9 @@ class RunManager:
"created_at": record.created_at, "created_at": record.created_at,
"model_name": record.model_name, "model_name": record.model_name,
} }
if record.user_id is not None:
payload["user_id"] = record.user_id
return payload
async def _call_store_with_retry( async def _call_store_with_retry(
self, self,
@@ -241,6 +245,7 @@ class RunManager:
kwargs=row.get("kwargs") or {}, kwargs=row.get("kwargs") or {},
created_at=row.get("created_at") or "", created_at=row.get("created_at") or "",
updated_at=row.get("updated_at") or "", updated_at=row.get("updated_at") or "",
user_id=row.get("user_id"),
error=row.get("error"), error=row.get("error"),
model_name=row.get("model_name"), model_name=row.get("model_name"),
store_only=True, store_only=True,
@@ -320,6 +325,7 @@ class RunManager:
metadata: dict | None = None, metadata: dict | None = None,
kwargs: dict | None = None, kwargs: dict | None = None,
multitask_strategy: str = "reject", multitask_strategy: str = "reject",
user_id: str | None = None,
) -> RunRecord: ) -> RunRecord:
"""Create a new pending run and register it.""" """Create a new pending run and register it."""
run_id = str(uuid.uuid4()) run_id = str(uuid.uuid4())
@@ -333,6 +339,7 @@ class RunManager:
multitask_strategy=multitask_strategy, multitask_strategy=multitask_strategy,
metadata=metadata or {}, metadata=metadata or {},
kwargs=kwargs or {}, kwargs=kwargs or {},
user_id=user_id,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
) )
@@ -504,6 +511,7 @@ class RunManager:
kwargs: dict | None = None, kwargs: dict | None = None,
multitask_strategy: str = "reject", multitask_strategy: str = "reject",
model_name: str | None = None, model_name: str | None = None,
user_id: str | None = None,
) -> RunRecord: ) -> RunRecord:
"""Atomically check for inflight runs and create a new one. """Atomically check for inflight runs and create a new one.
@@ -546,6 +554,7 @@ class RunManager:
multitask_strategy=multitask_strategy, multitask_strategy=multitask_strategy,
metadata=metadata or {}, metadata=metadata or {},
kwargs=kwargs or {}, kwargs=kwargs or {},
user_id=user_id,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
model_name=model_name, model_name=model_name,
+11
View File
@@ -2388,6 +2388,7 @@ class TestResolveRunParamsUserId:
class TestChannelManagerConnectionRouting: class TestChannelManagerConnectionRouting:
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path): def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path):
from app.channels.manager import ChannelManager from app.channels.manager import ChannelManager
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
from deerflow.persistence.engine import close_engine from deerflow.persistence.engine import close_engine
async def go(): async def go():
@@ -2453,6 +2454,16 @@ class TestChannelManagerConnectionRouting:
assert second_context["user_id"] == "bob" assert second_context["user_id"] == "bob"
assert second_context["channel_user_id"] == "U-bob" assert second_context["channel_user_id"] == "U-bob"
first_create_headers = mock_client.threads.create.call_args_list[0].kwargs["headers"]
second_create_headers = mock_client.threads.create.call_args_list[1].kwargs["headers"]
assert first_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
assert second_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
first_run_headers = mock_client.runs.wait.call_args_list[0].kwargs["headers"]
second_run_headers = mock_client.runs.wait.call_args_list[1].kwargs["headers"]
assert first_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
assert second_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
try: try:
_run(go()) _run(go())
finally: finally:
+77
View File
@@ -474,6 +474,83 @@ def test_inject_authenticated_user_context_skips_internal_role():
assert config["context"]["user_id"] == "channel-user-7" assert config["context"]["user_id"] == "channel-user-7"
def test_start_run_uses_internal_owner_header_for_persistence():
import asyncio
from types import SimpleNamespace
from unittest.mock import patch
from langgraph.checkpoint.memory import InMemorySaver
from langgraph.store.memory import InMemoryStore
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
from app.gateway.services import start_run
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
from deerflow.runtime import RunManager
from deerflow.runtime.runs.store.memory import MemoryRunStore
from deerflow.runtime.user_context import get_effective_user_id
async def _scenario():
run_store = MemoryRunStore()
thread_store = MemoryThreadMetaStore(InMemoryStore())
await thread_store.create("channel-thread", user_id="default", metadata={"legacy": True})
run_manager = RunManager(store=run_store)
state = SimpleNamespace(
stream_bridge=SimpleNamespace(),
run_manager=run_manager,
checkpointer=InMemorySaver(),
store=InMemoryStore(),
run_event_store=SimpleNamespace(),
run_events_config=None,
thread_store=thread_store,
)
request = SimpleNamespace(
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
app=SimpleNamespace(state=state),
)
body = SimpleNamespace(
assistant_id="lead_agent",
input={"messages": [{"role": "human", "content": "hi"}]},
metadata={},
config=None,
context=None,
on_disconnect="cancel",
multitask_strategy="reject",
stream_mode=None,
stream_subgraphs=False,
interrupt_before=None,
interrupt_after=None,
)
task_context: dict[str, str] = {}
async def fake_run_agent(*args, **kwargs):
task_context["user_id"] = get_effective_user_id()
with (
patch("app.gateway.services.resolve_agent_factory", return_value=object()),
patch("app.gateway.services.run_agent", side_effect=fake_run_agent),
):
record = await start_run(body, "channel-thread", request)
await record.task
owner_run = await run_store.get(record.run_id, user_id="owner-1")
default_run = await run_store.get(record.run_id, user_id="default")
owner_thread = await thread_store.get("channel-thread", user_id="owner-1")
default_thread = await thread_store.get("channel-thread", user_id="default")
return owner_run, default_run, owner_thread, default_thread, task_context
owner_run, default_run, owner_thread, default_thread, task_context = asyncio.run(_scenario())
assert owner_run is not None
assert owner_run["user_id"] == "owner-1"
assert default_run is None
assert owner_thread is not None
assert owner_thread["user_id"] == "owner-1"
assert owner_thread["metadata"] == {"legacy": True}
assert default_thread is None
assert task_context["user_id"] == "owner-1"
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0) # build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
+15
View File
@@ -33,3 +33,18 @@ def test_internal_auth_generates_process_local_fallback(monkeypatch):
assert reloaded.is_valid_internal_auth_token(token) is True assert reloaded.is_valid_internal_auth_token(token) is True
finally: finally:
importlib.reload(reloaded) importlib.reload(reloaded)
def test_internal_auth_headers_can_carry_owner_user_id(monkeypatch):
import app.gateway.internal_auth as internal_auth
monkeypatch.setenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", "shared-token")
reloaded = importlib.reload(internal_auth)
try:
headers = reloaded.create_internal_auth_headers(owner_user_id="owner-1")
assert headers[reloaded.INTERNAL_AUTH_HEADER_NAME] == "shared-token"
assert headers[reloaded.INTERNAL_OWNER_USER_ID_HEADER_NAME] == "owner-1"
finally:
monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False)
importlib.reload(reloaded)
+13
View File
@@ -137,6 +137,19 @@ class TestThreadMetaRepository:
async def test_update_metadata_nonexistent_is_noop(self, repo): async def test_update_metadata_nonexistent_is_noop(self, repo):
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
@pytest.mark.anyio
async def test_update_owner_with_bypass_moves_row(self, repo):
await repo.create("t1", user_id="default", metadata={"source": "channel"})
await repo.update_owner("t1", "owner-1", user_id=None)
owner_row = await repo.get("t1", user_id="owner-1")
default_row = await repo.get("t1", user_id="default")
assert owner_row is not None
assert owner_row["user_id"] == "owner-1"
assert owner_row["metadata"] == {"source": "channel"}
assert default_row is None
# --- search with metadata filter (SQL push-down) --- # --- search with metadata filter (SQL push-down) ---
@pytest.mark.anyio @pytest.mark.anyio
+32
View File
@@ -1,4 +1,5 @@
import re import re
from types import SimpleNamespace
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@@ -218,6 +219,37 @@ def test_create_thread_returns_iso_timestamps() -> None:
assert body["created_at"] == body["updated_at"] assert body["created_at"] == body["updated_at"]
def test_internal_owner_header_assigns_thread_to_owner() -> None:
import asyncio
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
store = InMemoryStore()
checkpointer = InMemorySaver()
thread_store = MemoryThreadMetaStore(store)
request = SimpleNamespace(
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
app=SimpleNamespace(state=SimpleNamespace(checkpointer=checkpointer, thread_store=thread_store)),
)
async def _scenario():
response = await threads.create_thread(
threads.ThreadCreateRequest(thread_id="channel-thread", metadata={}),
request,
)
owner_row = await thread_store.get("channel-thread", user_id="owner-1")
internal_row = await thread_store.get("channel-thread", user_id="default")
return response, owner_row, internal_row
response, owner_row, internal_row = asyncio.run(_scenario())
assert response.thread_id == "channel-thread"
assert owner_row is not None
assert owner_row["user_id"] == "owner-1"
assert internal_row is None
def test_get_thread_returns_iso_for_legacy_unix_record() -> None: def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
"""A thread record written by older versions stores ``time.time()`` """A thread record written by older versions stores ``time.time()``
floats. ``get_thread`` must transparently surface them as ISO so the floats. ``get_thread`` must transparently surface them as ISO so the