mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Make channel threads visible to connection owners
This commit is contained in:
@@ -440,6 +440,12 @@ def _human_input_message(content: str, *, original_content: str | None = None) -
|
|||||||
return message
|
return message
|
||||||
|
|
||||||
|
|
||||||
|
def _owner_headers(msg: InboundMessage) -> dict[str, str] | None:
|
||||||
|
if not msg.owner_user_id:
|
||||||
|
return None
|
||||||
|
return create_internal_auth_headers(owner_user_id=msg.owner_user_id)
|
||||||
|
|
||||||
|
|
||||||
def _resolve_slash_skill_command(
|
def _resolve_slash_skill_command(
|
||||||
text: str,
|
text: str,
|
||||||
available_skills: set[str] | None = None,
|
available_skills: set[str] | None = None,
|
||||||
@@ -914,7 +920,11 @@ class ChannelManager:
|
|||||||
|
|
||||||
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
async def _create_thread(self, client, msg: InboundMessage) -> str:
|
||||||
"""Create a new thread through Gateway and store the mapping."""
|
"""Create a new thread through Gateway and store the mapping."""
|
||||||
thread = await client.threads.create()
|
owner_headers = _owner_headers(msg)
|
||||||
|
if owner_headers:
|
||||||
|
thread = await client.threads.create(headers=owner_headers)
|
||||||
|
else:
|
||||||
|
thread = await client.threads.create()
|
||||||
thread_id = thread["thread_id"]
|
thread_id = thread["thread_id"]
|
||||||
await self._store_thread_id(msg, thread_id)
|
await self._store_thread_id(msg, thread_id)
|
||||||
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
logger.info("[Manager] new thread created through Gateway: thread_id=%s for chat_id=%s topic_id=%s", thread_id, msg.chat_id, msg.topic_id)
|
||||||
@@ -969,14 +979,19 @@ class ChannelManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||||
|
run_kwargs: dict[str, Any] = {
|
||||||
|
"input": {"messages": [human_message]},
|
||||||
|
"config": run_config,
|
||||||
|
"context": run_context,
|
||||||
|
"multitask_strategy": "reject",
|
||||||
|
}
|
||||||
|
if owner_headers := _owner_headers(msg):
|
||||||
|
run_kwargs["headers"] = owner_headers
|
||||||
try:
|
try:
|
||||||
result = await client.runs.wait(
|
result = await client.runs.wait(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id,
|
assistant_id,
|
||||||
input={"messages": [human_message]},
|
**run_kwargs,
|
||||||
config=run_config,
|
|
||||||
context=run_context,
|
|
||||||
multitask_strategy="reject",
|
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
if _is_thread_busy_error(exc):
|
if _is_thread_busy_error(exc):
|
||||||
@@ -1039,16 +1054,21 @@ class ChannelManager:
|
|||||||
last_published_text = ""
|
last_published_text = ""
|
||||||
last_publish_at = 0.0
|
last_publish_at = 0.0
|
||||||
stream_error: BaseException | None = None
|
stream_error: BaseException | None = None
|
||||||
|
stream_kwargs: dict[str, Any] = {
|
||||||
|
"input": {"messages": [human_message]},
|
||||||
|
"config": run_config,
|
||||||
|
"context": run_context,
|
||||||
|
"stream_mode": ["messages-tuple", "values"],
|
||||||
|
"multitask_strategy": "reject",
|
||||||
|
}
|
||||||
|
if owner_headers := _owner_headers(msg):
|
||||||
|
stream_kwargs["headers"] = owner_headers
|
||||||
|
|
||||||
try:
|
try:
|
||||||
async for chunk in client.runs.stream(
|
async for chunk in client.runs.stream(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id,
|
assistant_id,
|
||||||
input={"messages": [human_message]},
|
**stream_kwargs,
|
||||||
config=run_config,
|
|
||||||
context=run_context,
|
|
||||||
stream_mode=["messages-tuple", "values"],
|
|
||||||
multitask_strategy="reject",
|
|
||||||
):
|
):
|
||||||
event = getattr(chunk, "event", "")
|
event = getattr(chunk, "event", "")
|
||||||
data = getattr(chunk, "data", None)
|
data = getattr(chunk, "data", None)
|
||||||
|
|||||||
@@ -276,6 +276,11 @@ def require_permission(
|
|||||||
# strict-deny rather than strict-allow — only an *existing*
|
# strict-deny rather than strict-allow — only an *existing*
|
||||||
# row with a *different* user_id triggers 404.
|
# row with a *different* user_id triggers 404.
|
||||||
if owner_check:
|
if owner_check:
|
||||||
|
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
||||||
|
|
||||||
|
if getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
|
||||||
thread_id = kwargs.get("thread_id")
|
thread_id = kwargs.get("thread_id")
|
||||||
if thread_id is None:
|
if thread_id is None:
|
||||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ from __future__ import annotations
|
|||||||
import os
|
import os
|
||||||
import secrets
|
import secrets
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||||
|
|
||||||
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
||||||
|
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
|
||||||
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
||||||
INTERNAL_SYSTEM_ROLE = "internal"
|
INTERNAL_SYSTEM_ROLE = "internal"
|
||||||
|
|
||||||
@@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str:
|
|||||||
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
||||||
|
|
||||||
|
|
||||||
def create_internal_auth_headers() -> dict[str, str]:
|
def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
|
||||||
"""Return headers that authenticate trusted Gateway internal calls."""
|
"""Return headers that authenticate trusted Gateway internal calls."""
|
||||||
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||||
|
if owner_user_id:
|
||||||
|
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
def is_valid_internal_auth_token(token: str | None) -> bool:
|
def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||||
@@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
|
|||||||
def get_internal_user():
|
def get_internal_user():
|
||||||
"""Return the synthetic user used for trusted internal channel calls."""
|
"""Return the synthetic user used for trusted internal channel calls."""
|
||||||
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
||||||
|
|
||||||
|
|
||||||
|
def get_trusted_internal_owner_user_id(request: Any) -> str | None:
|
||||||
|
"""Return the owner override for a trusted internal request, if present.
|
||||||
|
|
||||||
|
The header is ignored for normal browser/API callers. It is only honored
|
||||||
|
after ``AuthMiddleware`` has validated the internal auth token and stamped
|
||||||
|
the synthetic internal user onto ``request.state.user``.
|
||||||
|
"""
|
||||||
|
user = getattr(getattr(request, "state", None), "user", None)
|
||||||
|
if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
|
||||||
|
return None
|
||||||
|
|
||||||
|
owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
|
||||||
|
if not owner_user_id:
|
||||||
|
return None
|
||||||
|
owner_user_id = owner_user_id.strip()
|
||||||
|
return owner_user_id or None
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, field_validator
|
|||||||
|
|
||||||
from app.gateway.authz import require_permission
|
from app.gateway.authz import require_permission
|
||||||
from app.gateway.deps import get_checkpointer
|
from app.gateway.deps import get_checkpointer
|
||||||
|
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.paths import Paths, get_paths
|
from deerflow.config.paths import Paths, get_paths
|
||||||
from deerflow.runtime import serialize_channel_values
|
from deerflow.runtime import serialize_channel_values
|
||||||
@@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
thread_store = get_thread_store(request)
|
thread_store = get_thread_store(request)
|
||||||
thread_id = body.thread_id or str(uuid.uuid4())
|
thread_id = body.thread_id or str(uuid.uuid4())
|
||||||
now = now_iso()
|
now = now_iso()
|
||||||
|
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||||
|
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
|
||||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||||
|
|
||||||
# Idempotency: return existing record when already present
|
# Idempotency: return existing record when already present
|
||||||
existing_record = await thread_store.get(thread_id)
|
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||||
|
if existing_record is None and thread_owner_user_id:
|
||||||
|
unscoped_record = await thread_store.get(thread_id, user_id=None)
|
||||||
|
if unscoped_record is not None:
|
||||||
|
if unscoped_record.get("user_id") != thread_owner_user_id:
|
||||||
|
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
|
||||||
|
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||||
if existing_record is not None:
|
if existing_record is not None:
|
||||||
return ThreadResponse(
|
return ThreadResponse(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
await thread_store.create(
|
await thread_store.create(
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=getattr(body, "assistant_id", None),
|
assistant_id=getattr(body, "assistant_id", None),
|
||||||
|
**thread_owner_kwargs,
|
||||||
metadata=body.metadata,
|
metadata=body.metadata,
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from collections.abc import Mapping
|
from collections.abc import Mapping
|
||||||
|
from types import SimpleNamespace
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import HTTPException, Request
|
from fastapi import HTTPException, Request
|
||||||
@@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage
|
|||||||
from langchain_core.messages.utils import convert_to_messages
|
from langchain_core.messages.utils import convert_to_messages
|
||||||
|
|
||||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
||||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.app_config import get_app_config
|
from deerflow.config.app_config import get_app_config
|
||||||
from deerflow.runtime import (
|
from deerflow.runtime import (
|
||||||
@@ -35,6 +36,7 @@ from deerflow.runtime import (
|
|||||||
run_agent,
|
run_agent,
|
||||||
)
|
)
|
||||||
from deerflow.runtime.runs.naming import resolve_root_run_name
|
from deerflow.runtime.runs.naming import resolve_root_run_name
|
||||||
|
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -315,72 +317,85 @@ async def start_run(
|
|||||||
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||||
|
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
|
||||||
try:
|
try:
|
||||||
record = await run_mgr.create_or_reject(
|
try:
|
||||||
thread_id,
|
record = await run_mgr.create_or_reject(
|
||||||
body.assistant_id,
|
|
||||||
on_disconnect=disconnect,
|
|
||||||
metadata=body.metadata or {},
|
|
||||||
kwargs={"input": body.input, "config": body.config},
|
|
||||||
multitask_strategy=body.multitask_strategy,
|
|
||||||
model_name=model_name,
|
|
||||||
)
|
|
||||||
except ConflictError as exc:
|
|
||||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
|
||||||
except UnsupportedStrategyError as exc:
|
|
||||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
|
||||||
|
|
||||||
# Upsert thread metadata so the thread appears in /threads/search,
|
|
||||||
# even for threads that were never explicitly created via POST /threads
|
|
||||||
# (e.g. stateless runs).
|
|
||||||
try:
|
|
||||||
existing = await run_ctx.thread_store.get(thread_id)
|
|
||||||
if existing is None:
|
|
||||||
await run_ctx.thread_store.create(
|
|
||||||
thread_id,
|
thread_id,
|
||||||
assistant_id=body.assistant_id,
|
body.assistant_id,
|
||||||
metadata=body.metadata,
|
on_disconnect=disconnect,
|
||||||
|
metadata=body.metadata or {},
|
||||||
|
kwargs={"input": body.input, "config": body.config},
|
||||||
|
multitask_strategy=body.multitask_strategy,
|
||||||
|
model_name=model_name,
|
||||||
|
user_id=owner_user_id,
|
||||||
)
|
)
|
||||||
else:
|
except ConflictError as exc:
|
||||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||||
except Exception:
|
except UnsupportedStrategyError as exc:
|
||||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||||
|
|
||||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
# Upsert thread metadata so the thread appears in /threads/search,
|
||||||
graph_input = normalize_input(body.input)
|
# even for threads that were never explicitly created via POST /threads
|
||||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
# (e.g. stateless runs).
|
||||||
|
try:
|
||||||
|
existing = await run_ctx.thread_store.get(thread_id)
|
||||||
|
if existing is None and owner_user_id:
|
||||||
|
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
|
||||||
|
if unscoped_existing is not None:
|
||||||
|
if unscoped_existing.get("user_id") != owner_user_id:
|
||||||
|
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
|
||||||
|
existing = await run_ctx.thread_store.get(thread_id)
|
||||||
|
if existing is None:
|
||||||
|
await run_ctx.thread_store.create(
|
||||||
|
thread_id,
|
||||||
|
assistant_id=body.assistant_id,
|
||||||
|
metadata=body.metadata,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||||
|
|
||||||
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
graph_input = normalize_input(body.input)
|
||||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
|
||||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
|
||||||
inject_authenticated_user_context(config, request)
|
|
||||||
|
|
||||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
||||||
|
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||||
|
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||||
|
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||||
|
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||||
|
inject_authenticated_user_context(config, request)
|
||||||
|
|
||||||
task = asyncio.create_task(
|
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||||
run_agent(
|
|
||||||
bridge,
|
task = asyncio.create_task(
|
||||||
run_mgr,
|
run_agent(
|
||||||
record,
|
bridge,
|
||||||
ctx=run_ctx,
|
run_mgr,
|
||||||
agent_factory=agent_factory,
|
record,
|
||||||
graph_input=graph_input,
|
ctx=run_ctx,
|
||||||
config=config,
|
agent_factory=agent_factory,
|
||||||
stream_modes=stream_modes,
|
graph_input=graph_input,
|
||||||
stream_subgraphs=body.stream_subgraphs,
|
config=config,
|
||||||
interrupt_before=body.interrupt_before,
|
stream_modes=stream_modes,
|
||||||
interrupt_after=body.interrupt_after,
|
stream_subgraphs=body.stream_subgraphs,
|
||||||
|
interrupt_before=body.interrupt_before,
|
||||||
|
interrupt_after=body.interrupt_after,
|
||||||
|
)
|
||||||
)
|
)
|
||||||
)
|
record.task = task
|
||||||
record.task = task
|
|
||||||
|
|
||||||
# Title sync is handled by worker.py's finally block which reads the
|
# Title sync is handled by worker.py's finally block which reads the
|
||||||
# title from the checkpoint and calls thread_store.update_display_name
|
# title from the checkpoint and calls thread_store.update_display_name
|
||||||
# after the run completes.
|
# after the run completes.
|
||||||
|
|
||||||
return record
|
return record
|
||||||
|
finally:
|
||||||
|
if owner_context_token is not None:
|
||||||
|
reset_current_user(owner_context_token)
|
||||||
|
|
||||||
|
|
||||||
async def sse_consumer(
|
async def sse_consumer(
|
||||||
|
|||||||
@@ -71,6 +71,15 @@ class ThreadMetaStore(abc.ABC):
|
|||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
|
"""Move a thread metadata row to a new owner.
|
||||||
|
|
||||||
|
Intended for trusted internal repair/migration paths. No-op if the
|
||||||
|
row does not exist or the caller fails the owner check.
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
async def check_access(self, thread_id: str, user_id: str, *, require_existing: bool = False) -> bool:
|
||||||
"""Check if ``user_id`` has access to ``thread_id``."""
|
"""Check if ``user_id`` has access to ``thread_id``."""
|
||||||
|
|||||||
@@ -127,6 +127,14 @@ class MemoryThreadMetaStore(ThreadMetaStore):
|
|||||||
record["updated_at"] = now_iso()
|
record["updated_at"] = now_iso()
|
||||||
await self._store.aput(THREADS_NS, thread_id, record)
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
|
|
||||||
|
async def update_owner(self, thread_id: str, owner_user_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
|
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.update_owner")
|
||||||
|
if record is None:
|
||||||
|
return
|
||||||
|
record["user_id"] = owner_user_id
|
||||||
|
record["updated_at"] = now_iso()
|
||||||
|
await self._store.aput(THREADS_NS, thread_id, record)
|
||||||
|
|
||||||
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
async def delete(self, thread_id: str, *, user_id: str | None | _AutoSentinel = AUTO) -> None:
|
||||||
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
record = await self._get_owned_record(thread_id, user_id, "MemoryThreadMetaStore.delete")
|
||||||
if record is None:
|
if record is None:
|
||||||
|
|||||||
@@ -211,6 +211,21 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
row.updated_at = datetime.now(UTC)
|
row.updated_at = datetime.now(UTC)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
async def update_owner(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
owner_user_id: str,
|
||||||
|
*,
|
||||||
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
|
) -> None:
|
||||||
|
"""Move a thread metadata row to ``owner_user_id``."""
|
||||||
|
resolved_user_id = resolve_user_id(user_id, method_name="ThreadMetaRepository.update_owner")
|
||||||
|
async with self._sf() as session:
|
||||||
|
if not await self._check_ownership(session, thread_id, resolved_user_id):
|
||||||
|
return
|
||||||
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(user_id=owner_user_id, updated_at=datetime.now(UTC)))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
async def delete(
|
async def delete(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
|
|||||||
@@ -83,6 +83,7 @@ class RunRecord:
|
|||||||
multitask_strategy: str = "reject"
|
multitask_strategy: str = "reject"
|
||||||
metadata: dict = field(default_factory=dict)
|
metadata: dict = field(default_factory=dict)
|
||||||
kwargs: dict = field(default_factory=dict)
|
kwargs: dict = field(default_factory=dict)
|
||||||
|
user_id: str | None = None
|
||||||
created_at: str = ""
|
created_at: str = ""
|
||||||
updated_at: str = ""
|
updated_at: str = ""
|
||||||
task: asyncio.Task | None = field(default=None, repr=False)
|
task: asyncio.Task | None = field(default=None, repr=False)
|
||||||
@@ -124,7 +125,7 @@ class RunManager:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
|
def _store_put_payload(record: RunRecord, *, error: str | None = None) -> dict[str, Any]:
|
||||||
return {
|
payload = {
|
||||||
"thread_id": record.thread_id,
|
"thread_id": record.thread_id,
|
||||||
"assistant_id": record.assistant_id,
|
"assistant_id": record.assistant_id,
|
||||||
"status": record.status.value,
|
"status": record.status.value,
|
||||||
@@ -135,6 +136,9 @@ class RunManager:
|
|||||||
"created_at": record.created_at,
|
"created_at": record.created_at,
|
||||||
"model_name": record.model_name,
|
"model_name": record.model_name,
|
||||||
}
|
}
|
||||||
|
if record.user_id is not None:
|
||||||
|
payload["user_id"] = record.user_id
|
||||||
|
return payload
|
||||||
|
|
||||||
async def _call_store_with_retry(
|
async def _call_store_with_retry(
|
||||||
self,
|
self,
|
||||||
@@ -241,6 +245,7 @@ class RunManager:
|
|||||||
kwargs=row.get("kwargs") or {},
|
kwargs=row.get("kwargs") or {},
|
||||||
created_at=row.get("created_at") or "",
|
created_at=row.get("created_at") or "",
|
||||||
updated_at=row.get("updated_at") or "",
|
updated_at=row.get("updated_at") or "",
|
||||||
|
user_id=row.get("user_id"),
|
||||||
error=row.get("error"),
|
error=row.get("error"),
|
||||||
model_name=row.get("model_name"),
|
model_name=row.get("model_name"),
|
||||||
store_only=True,
|
store_only=True,
|
||||||
@@ -320,6 +325,7 @@ class RunManager:
|
|||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
|
user_id: str | None = None,
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Create a new pending run and register it."""
|
"""Create a new pending run and register it."""
|
||||||
run_id = str(uuid.uuid4())
|
run_id = str(uuid.uuid4())
|
||||||
@@ -333,6 +339,7 @@ class RunManager:
|
|||||||
multitask_strategy=multitask_strategy,
|
multitask_strategy=multitask_strategy,
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
kwargs=kwargs or {},
|
kwargs=kwargs or {},
|
||||||
|
user_id=user_id,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
)
|
)
|
||||||
@@ -504,6 +511,7 @@ class RunManager:
|
|||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
model_name: str | None = None,
|
model_name: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Atomically check for inflight runs and create a new one.
|
"""Atomically check for inflight runs and create a new one.
|
||||||
|
|
||||||
@@ -546,6 +554,7 @@ class RunManager:
|
|||||||
multitask_strategy=multitask_strategy,
|
multitask_strategy=multitask_strategy,
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
kwargs=kwargs or {},
|
kwargs=kwargs or {},
|
||||||
|
user_id=user_id,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
|
|||||||
@@ -2388,6 +2388,7 @@ class TestResolveRunParamsUserId:
|
|||||||
class TestChannelManagerConnectionRouting:
|
class TestChannelManagerConnectionRouting:
|
||||||
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path):
|
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path):
|
||||||
from app.channels.manager import ChannelManager
|
from app.channels.manager import ChannelManager
|
||||||
|
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||||
from deerflow.persistence.engine import close_engine
|
from deerflow.persistence.engine import close_engine
|
||||||
|
|
||||||
async def go():
|
async def go():
|
||||||
@@ -2453,6 +2454,16 @@ class TestChannelManagerConnectionRouting:
|
|||||||
assert second_context["user_id"] == "bob"
|
assert second_context["user_id"] == "bob"
|
||||||
assert second_context["channel_user_id"] == "U-bob"
|
assert second_context["channel_user_id"] == "U-bob"
|
||||||
|
|
||||||
|
first_create_headers = mock_client.threads.create.call_args_list[0].kwargs["headers"]
|
||||||
|
second_create_headers = mock_client.threads.create.call_args_list[1].kwargs["headers"]
|
||||||
|
assert first_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
|
||||||
|
assert second_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
|
||||||
|
|
||||||
|
first_run_headers = mock_client.runs.wait.call_args_list[0].kwargs["headers"]
|
||||||
|
second_run_headers = mock_client.runs.wait.call_args_list[1].kwargs["headers"]
|
||||||
|
assert first_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
|
||||||
|
assert second_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
_run(go())
|
_run(go())
|
||||||
finally:
|
finally:
|
||||||
|
|||||||
@@ -474,6 +474,83 @@ def test_inject_authenticated_user_context_skips_internal_role():
|
|||||||
assert config["context"]["user_id"] == "channel-user-7"
|
assert config["context"]["user_id"] == "channel-user-7"
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_run_uses_internal_owner_header_for_persistence():
|
||||||
|
import asyncio
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from langgraph.checkpoint.memory import InMemorySaver
|
||||||
|
from langgraph.store.memory import InMemoryStore
|
||||||
|
|
||||||
|
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||||
|
from app.gateway.services import start_run
|
||||||
|
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||||
|
from deerflow.runtime import RunManager
|
||||||
|
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||||
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
|
async def _scenario():
|
||||||
|
run_store = MemoryRunStore()
|
||||||
|
thread_store = MemoryThreadMetaStore(InMemoryStore())
|
||||||
|
await thread_store.create("channel-thread", user_id="default", metadata={"legacy": True})
|
||||||
|
run_manager = RunManager(store=run_store)
|
||||||
|
state = SimpleNamespace(
|
||||||
|
stream_bridge=SimpleNamespace(),
|
||||||
|
run_manager=run_manager,
|
||||||
|
checkpointer=InMemorySaver(),
|
||||||
|
store=InMemoryStore(),
|
||||||
|
run_event_store=SimpleNamespace(),
|
||||||
|
run_events_config=None,
|
||||||
|
thread_store=thread_store,
|
||||||
|
)
|
||||||
|
request = SimpleNamespace(
|
||||||
|
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
|
||||||
|
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
|
||||||
|
app=SimpleNamespace(state=state),
|
||||||
|
)
|
||||||
|
body = SimpleNamespace(
|
||||||
|
assistant_id="lead_agent",
|
||||||
|
input={"messages": [{"role": "human", "content": "hi"}]},
|
||||||
|
metadata={},
|
||||||
|
config=None,
|
||||||
|
context=None,
|
||||||
|
on_disconnect="cancel",
|
||||||
|
multitask_strategy="reject",
|
||||||
|
stream_mode=None,
|
||||||
|
stream_subgraphs=False,
|
||||||
|
interrupt_before=None,
|
||||||
|
interrupt_after=None,
|
||||||
|
)
|
||||||
|
task_context: dict[str, str] = {}
|
||||||
|
|
||||||
|
async def fake_run_agent(*args, **kwargs):
|
||||||
|
task_context["user_id"] = get_effective_user_id()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.gateway.services.resolve_agent_factory", return_value=object()),
|
||||||
|
patch("app.gateway.services.run_agent", side_effect=fake_run_agent),
|
||||||
|
):
|
||||||
|
record = await start_run(body, "channel-thread", request)
|
||||||
|
await record.task
|
||||||
|
|
||||||
|
owner_run = await run_store.get(record.run_id, user_id="owner-1")
|
||||||
|
default_run = await run_store.get(record.run_id, user_id="default")
|
||||||
|
owner_thread = await thread_store.get("channel-thread", user_id="owner-1")
|
||||||
|
default_thread = await thread_store.get("channel-thread", user_id="default")
|
||||||
|
return owner_run, default_run, owner_thread, default_thread, task_context
|
||||||
|
|
||||||
|
owner_run, default_run, owner_thread, default_thread, task_context = asyncio.run(_scenario())
|
||||||
|
|
||||||
|
assert owner_run is not None
|
||||||
|
assert owner_run["user_id"] == "owner-1"
|
||||||
|
assert default_run is None
|
||||||
|
assert owner_thread is not None
|
||||||
|
assert owner_thread["user_id"] == "owner-1"
|
||||||
|
assert owner_thread["metadata"] == {"legacy": True}
|
||||||
|
assert default_thread is None
|
||||||
|
assert task_context["user_id"] == "owner-1"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -33,3 +33,18 @@ def test_internal_auth_generates_process_local_fallback(monkeypatch):
|
|||||||
assert reloaded.is_valid_internal_auth_token(token) is True
|
assert reloaded.is_valid_internal_auth_token(token) is True
|
||||||
finally:
|
finally:
|
||||||
importlib.reload(reloaded)
|
importlib.reload(reloaded)
|
||||||
|
|
||||||
|
|
||||||
|
def test_internal_auth_headers_can_carry_owner_user_id(monkeypatch):
|
||||||
|
import app.gateway.internal_auth as internal_auth
|
||||||
|
|
||||||
|
monkeypatch.setenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", "shared-token")
|
||||||
|
reloaded = importlib.reload(internal_auth)
|
||||||
|
try:
|
||||||
|
headers = reloaded.create_internal_auth_headers(owner_user_id="owner-1")
|
||||||
|
|
||||||
|
assert headers[reloaded.INTERNAL_AUTH_HEADER_NAME] == "shared-token"
|
||||||
|
assert headers[reloaded.INTERNAL_OWNER_USER_ID_HEADER_NAME] == "owner-1"
|
||||||
|
finally:
|
||||||
|
monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False)
|
||||||
|
importlib.reload(reloaded)
|
||||||
|
|||||||
@@ -137,6 +137,19 @@ class TestThreadMetaRepository:
|
|||||||
async def test_update_metadata_nonexistent_is_noop(self, repo):
|
async def test_update_metadata_nonexistent_is_noop(self, repo):
|
||||||
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
|
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_update_owner_with_bypass_moves_row(self, repo):
|
||||||
|
await repo.create("t1", user_id="default", metadata={"source": "channel"})
|
||||||
|
await repo.update_owner("t1", "owner-1", user_id=None)
|
||||||
|
|
||||||
|
owner_row = await repo.get("t1", user_id="owner-1")
|
||||||
|
default_row = await repo.get("t1", user_id="default")
|
||||||
|
|
||||||
|
assert owner_row is not None
|
||||||
|
assert owner_row["user_id"] == "owner-1"
|
||||||
|
assert owner_row["metadata"] == {"source": "channel"}
|
||||||
|
assert default_row is None
|
||||||
|
|
||||||
# --- search with metadata filter (SQL push-down) ---
|
# --- search with metadata filter (SQL push-down) ---
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import re
|
import re
|
||||||
|
from types import SimpleNamespace
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -218,6 +219,37 @@ def test_create_thread_returns_iso_timestamps() -> None:
|
|||||||
assert body["created_at"] == body["updated_at"]
|
assert body["created_at"] == body["updated_at"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_internal_owner_header_assigns_thread_to_owner() -> None:
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||||
|
|
||||||
|
store = InMemoryStore()
|
||||||
|
checkpointer = InMemorySaver()
|
||||||
|
thread_store = MemoryThreadMetaStore(store)
|
||||||
|
request = SimpleNamespace(
|
||||||
|
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
|
||||||
|
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
|
||||||
|
app=SimpleNamespace(state=SimpleNamespace(checkpointer=checkpointer, thread_store=thread_store)),
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _scenario():
|
||||||
|
response = await threads.create_thread(
|
||||||
|
threads.ThreadCreateRequest(thread_id="channel-thread", metadata={}),
|
||||||
|
request,
|
||||||
|
)
|
||||||
|
owner_row = await thread_store.get("channel-thread", user_id="owner-1")
|
||||||
|
internal_row = await thread_store.get("channel-thread", user_id="default")
|
||||||
|
return response, owner_row, internal_row
|
||||||
|
|
||||||
|
response, owner_row, internal_row = asyncio.run(_scenario())
|
||||||
|
|
||||||
|
assert response.thread_id == "channel-thread"
|
||||||
|
assert owner_row is not None
|
||||||
|
assert owner_row["user_id"] == "owner-1"
|
||||||
|
assert internal_row is None
|
||||||
|
|
||||||
|
|
||||||
def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
|
def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
|
||||||
"""A thread record written by older versions stores ``time.time()``
|
"""A thread record written by older versions stores ``time.time()``
|
||||||
floats. ``get_thread`` must transparently surface them as ISO so the
|
floats. ``get_thread`` must transparently surface them as ISO so the
|
||||||
|
|||||||
Reference in New Issue
Block a user