mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 18:05:58 +00:00
Make channel threads visible to connection owners
This commit is contained in:
@@ -276,6 +276,11 @@ def require_permission(
|
||||
# strict-deny rather than strict-allow — only an *existing*
|
||||
# row with a *different* user_id triggers 404.
|
||||
if owner_check:
|
||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
||||
|
||||
if getattr(auth.user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||
return await func(*args, **kwargs)
|
||||
|
||||
thread_id = kwargs.get("thread_id")
|
||||
if thread_id is None:
|
||||
raise ValueError("require_permission with owner_check=True requires 'thread_id' parameter")
|
||||
|
||||
@@ -5,10 +5,12 @@ from __future__ import annotations
|
||||
import os
|
||||
import secrets
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from deerflow.runtime.user_context import DEFAULT_USER_ID
|
||||
|
||||
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
||||
INTERNAL_OWNER_USER_ID_HEADER_NAME = "X-DeerFlow-Owner-User-Id"
|
||||
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
||||
INTERNAL_SYSTEM_ROLE = "internal"
|
||||
|
||||
@@ -23,9 +25,12 @@ def _load_internal_auth_token() -> str:
|
||||
_INTERNAL_AUTH_TOKEN = _load_internal_auth_token()
|
||||
|
||||
|
||||
def create_internal_auth_headers() -> dict[str, str]:
|
||||
def create_internal_auth_headers(*, owner_user_id: str | None = None) -> dict[str, str]:
|
||||
"""Return headers that authenticate trusted Gateway internal calls."""
|
||||
return {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||
headers = {INTERNAL_AUTH_HEADER_NAME: _INTERNAL_AUTH_TOKEN}
|
||||
if owner_user_id:
|
||||
headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] = owner_user_id
|
||||
return headers
|
||||
|
||||
|
||||
def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||
@@ -36,3 +41,21 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
|
||||
def get_internal_user():
|
||||
"""Return the synthetic user used for trusted internal channel calls."""
|
||||
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
||||
|
||||
|
||||
def get_trusted_internal_owner_user_id(request: Any) -> str | None:
|
||||
"""Return the owner override for a trusted internal request, if present.
|
||||
|
||||
The header is ignored for normal browser/API callers. It is only honored
|
||||
after ``AuthMiddleware`` has validated the internal auth token and stamped
|
||||
the synthetic internal user onto ``request.state.user``.
|
||||
"""
|
||||
user = getattr(getattr(request, "state", None), "user", None)
|
||||
if getattr(user, "system_role", None) != INTERNAL_SYSTEM_ROLE:
|
||||
return None
|
||||
|
||||
owner_user_id = request.headers.get(INTERNAL_OWNER_USER_ID_HEADER_NAME)
|
||||
if not owner_user_id:
|
||||
return None
|
||||
owner_user_id = owner_user_id.strip()
|
||||
return owner_user_id or None
|
||||
|
||||
@@ -22,6 +22,7 @@ from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from app.gateway.authz import require_permission
|
||||
from app.gateway.deps import get_checkpointer
|
||||
from app.gateway.internal_auth import get_trusted_internal_owner_user_id
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.paths import Paths, get_paths
|
||||
from deerflow.runtime import serialize_channel_values
|
||||
@@ -257,11 +258,19 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
thread_store = get_thread_store(request)
|
||||
thread_id = body.thread_id or str(uuid.uuid4())
|
||||
now = now_iso()
|
||||
thread_owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||
thread_owner_kwargs = {"user_id": thread_owner_user_id} if thread_owner_user_id else {}
|
||||
# ``body.metadata`` is already stripped of server-reserved keys by
|
||||
# ``ThreadCreateRequest._strip_reserved`` — see the model definition.
|
||||
|
||||
# Idempotency: return existing record when already present
|
||||
existing_record = await thread_store.get(thread_id)
|
||||
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||
if existing_record is None and thread_owner_user_id:
|
||||
unscoped_record = await thread_store.get(thread_id, user_id=None)
|
||||
if unscoped_record is not None:
|
||||
if unscoped_record.get("user_id") != thread_owner_user_id:
|
||||
await thread_store.update_owner(thread_id, thread_owner_user_id, user_id=None)
|
||||
existing_record = await thread_store.get(thread_id, **thread_owner_kwargs)
|
||||
if existing_record is not None:
|
||||
return ThreadResponse(
|
||||
thread_id=thread_id,
|
||||
@@ -276,6 +285,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
||||
await thread_store.create(
|
||||
thread_id,
|
||||
assistant_id=getattr(body, "assistant_id", None),
|
||||
**thread_owner_kwargs,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -12,6 +12,7 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from types import SimpleNamespace
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
@@ -19,7 +20,7 @@ from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages.utils import convert_to_messages
|
||||
|
||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE, get_trusted_internal_owner_user_id
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
from deerflow.config.app_config import get_app_config
|
||||
from deerflow.runtime import (
|
||||
@@ -35,6 +36,7 @@ from deerflow.runtime import (
|
||||
run_agent,
|
||||
)
|
||||
from deerflow.runtime.runs.naming import resolve_root_run_name
|
||||
from deerflow.runtime.user_context import reset_current_user, set_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -315,72 +317,85 @@ async def start_run(
|
||||
detail=f"Model {model_name!r} is not in the configured model allowlist",
|
||||
)
|
||||
|
||||
owner_user_id = get_trusted_internal_owner_user_id(request)
|
||||
owner_context_token = set_current_user(SimpleNamespace(id=owner_user_id)) if owner_user_id else None
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
body.assistant_id,
|
||||
on_disconnect=disconnect,
|
||||
metadata=body.metadata or {},
|
||||
kwargs={"input": body.input, "config": body.config},
|
||||
multitask_strategy=body.multitask_strategy,
|
||||
model_name=model_name,
|
||||
)
|
||||
except ConflictError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except UnsupportedStrategyError as exc:
|
||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||
|
||||
# Upsert thread metadata so the thread appears in /threads/search,
|
||||
# even for threads that were never explicitly created via POST /threads
|
||||
# (e.g. stateless runs).
|
||||
try:
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None:
|
||||
await run_ctx.thread_store.create(
|
||||
try:
|
||||
record = await run_mgr.create_or_reject(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
body.assistant_id,
|
||||
on_disconnect=disconnect,
|
||||
metadata=body.metadata or {},
|
||||
kwargs={"input": body.input, "config": body.config},
|
||||
multitask_strategy=body.multitask_strategy,
|
||||
model_name=model_name,
|
||||
user_id=owner_user_id,
|
||||
)
|
||||
else:
|
||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
except ConflictError as exc:
|
||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||
except UnsupportedStrategyError as exc:
|
||||
raise HTTPException(status_code=501, detail=str(exc)) from exc
|
||||
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
# Upsert thread metadata so the thread appears in /threads/search,
|
||||
# even for threads that were never explicitly created via POST /threads
|
||||
# (e.g. stateless runs).
|
||||
try:
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None and owner_user_id:
|
||||
unscoped_existing = await run_ctx.thread_store.get(thread_id, user_id=None)
|
||||
if unscoped_existing is not None:
|
||||
if unscoped_existing.get("user_id") != owner_user_id:
|
||||
await run_ctx.thread_store.update_owner(thread_id, owner_user_id, user_id=None)
|
||||
existing = await run_ctx.thread_store.get(thread_id)
|
||||
if existing is None:
|
||||
await run_ctx.thread_store.create(
|
||||
thread_id,
|
||||
assistant_id=body.assistant_id,
|
||||
metadata=body.metadata,
|
||||
)
|
||||
else:
|
||||
await run_ctx.thread_store.update_status(thread_id, "running")
|
||||
except Exception:
|
||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", sanitize_log_param(thread_id))
|
||||
|
||||
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||
inject_authenticated_user_context(config, request)
|
||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||
graph_input = normalize_input(body.input)
|
||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
# Merge DeerFlow-specific context overrides into both ``configurable`` and ``context``.
|
||||
# The ``context`` field is a custom extension for the langgraph-compat layer
|
||||
# that carries agent configuration (model_name, thinking_enabled, etc.).
|
||||
# Only agent-relevant keys are forwarded; unknown keys (e.g. thread_id) are ignored.
|
||||
merge_run_context_overrides(config, getattr(body, "context", None))
|
||||
inject_authenticated_user_context(config, request)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_agent(
|
||||
bridge,
|
||||
run_mgr,
|
||||
record,
|
||||
ctx=run_ctx,
|
||||
agent_factory=agent_factory,
|
||||
graph_input=graph_input,
|
||||
config=config,
|
||||
stream_modes=stream_modes,
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
stream_modes = normalize_stream_modes(body.stream_mode)
|
||||
|
||||
task = asyncio.create_task(
|
||||
run_agent(
|
||||
bridge,
|
||||
run_mgr,
|
||||
record,
|
||||
ctx=run_ctx,
|
||||
agent_factory=agent_factory,
|
||||
graph_input=graph_input,
|
||||
config=config,
|
||||
stream_modes=stream_modes,
|
||||
stream_subgraphs=body.stream_subgraphs,
|
||||
interrupt_before=body.interrupt_before,
|
||||
interrupt_after=body.interrupt_after,
|
||||
)
|
||||
)
|
||||
)
|
||||
record.task = task
|
||||
record.task = task
|
||||
|
||||
# Title sync is handled by worker.py's finally block which reads the
|
||||
# title from the checkpoint and calls thread_store.update_display_name
|
||||
# after the run completes.
|
||||
# Title sync is handled by worker.py's finally block which reads the
|
||||
# title from the checkpoint and calls thread_store.update_display_name
|
||||
# after the run completes.
|
||||
|
||||
return record
|
||||
return record
|
||||
finally:
|
||||
if owner_context_token is not None:
|
||||
reset_current_user(owner_context_token)
|
||||
|
||||
|
||||
async def sse_consumer(
|
||||
|
||||
Reference in New Issue
Block a user