mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
fix(mcp): add auth interceptor with channel user_id and keep header propagation to mcp tools (#3294)
* 修复channel中的user_id传递到interceptor中的bug, mcp可通过header传递user_id到mcp工具 Co-authored-by: Cursor <cursoragent@cursor.com> * fix(channel,mcp,gateway): normalize channel user_id and add regression tests Normalize external channel user ids into filesystem-safe runtime context while preserving raw channel_user_id, and document gateway user_id propagation semantics. Add regression coverage for channel user_id context mapping, gateway user_id precedence/internal-role behavior, and MCP interceptor header forwarding via meta.headers. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(auth,mcp): harden user id normalization and header handling Increase sanitized user-id digest suffix to 16 hex chars, replace internal system role magic string with a shared constant, and harden MCP header forwarding with Mapping type checks. Add regression tests for empty channel user_id handling, unsupported header types, and updated digest length behavior. Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: zhongli <335302680@qq.com> Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -26,6 +26,7 @@ from app.channels.message_bus import (
|
|||||||
from app.channels.store import ChannelStore
|
from app.channels.store import ChannelStore
|
||||||
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
|
from app.gateway.csrf_middleware import CSRF_COOKIE_NAME, CSRF_HEADER_NAME, generate_csrf_token
|
||||||
from app.gateway.internal_auth import create_internal_auth_headers
|
from app.gateway.internal_auth import create_internal_auth_headers
|
||||||
|
from deerflow.config.paths import make_safe_user_id
|
||||||
from deerflow.runtime.user_context import get_effective_user_id
|
from deerflow.runtime.user_context import get_effective_user_id
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -670,12 +671,20 @@ class ChannelManager:
|
|||||||
configurable["checkpoint_ns"] = ""
|
configurable["checkpoint_ns"] = ""
|
||||||
configurable["thread_id"] = thread_id
|
configurable["thread_id"] = thread_id
|
||||||
|
|
||||||
|
# ``user_id`` drives user-scoped filesystem buckets that only accept
|
||||||
|
# ``[A-Za-z0-9_-]``, so normalize the channel id and keep the raw value
|
||||||
|
# under ``channel_user_id`` for platform-facing lookups.
|
||||||
|
run_context_identity: dict[str, Any] = {"thread_id": thread_id}
|
||||||
|
if msg.user_id:
|
||||||
|
run_context_identity["user_id"] = make_safe_user_id(msg.user_id)
|
||||||
|
run_context_identity["channel_user_id"] = msg.user_id
|
||||||
|
|
||||||
run_context = _merge_dicts(
|
run_context = _merge_dicts(
|
||||||
DEFAULT_RUN_CONTEXT,
|
DEFAULT_RUN_CONTEXT,
|
||||||
self._default_session.get("context"),
|
self._default_session.get("context"),
|
||||||
channel_layer.get("context"),
|
channel_layer.get("context"),
|
||||||
user_layer.get("context"),
|
user_layer.get("context"),
|
||||||
{"thread_id": thread_id},
|
run_context_identity,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Custom agents are implemented as lead_agent + agent_name context.
|
# Custom agents are implemented as lead_agent + agent_name context.
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from deerflow.runtime.user_context import DEFAULT_USER_ID
|
|||||||
|
|
||||||
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
INTERNAL_AUTH_HEADER_NAME = "X-DeerFlow-Internal-Token"
|
||||||
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
INTERNAL_AUTH_ENV_VAR = "DEER_FLOW_INTERNAL_AUTH_TOKEN"
|
||||||
|
INTERNAL_SYSTEM_ROLE = "internal"
|
||||||
|
|
||||||
|
|
||||||
def _load_internal_auth_token() -> str:
|
def _load_internal_auth_token() -> str:
|
||||||
@@ -34,4 +35,4 @@ def is_valid_internal_auth_token(token: str | None) -> bool:
|
|||||||
|
|
||||||
def get_internal_user():
|
def get_internal_user():
|
||||||
"""Return the synthetic user used for trusted internal channel calls."""
|
"""Return the synthetic user used for trusted internal channel calls."""
|
||||||
return SimpleNamespace(id=DEFAULT_USER_ID, system_role="internal")
|
return SimpleNamespace(id=DEFAULT_USER_ID, system_role=INTERNAL_SYSTEM_ROLE)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from langchain_core.messages import BaseMessage
|
|||||||
from langchain_core.messages.utils import convert_to_messages
|
from langchain_core.messages.utils import convert_to_messages
|
||||||
|
|
||||||
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
from app.gateway.deps import get_run_context, get_run_manager, get_stream_bridge
|
||||||
|
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
||||||
from app.gateway.utils import sanitize_log_param
|
from app.gateway.utils import sanitize_log_param
|
||||||
from deerflow.config.app_config import get_app_config
|
from deerflow.config.app_config import get_app_config
|
||||||
from deerflow.runtime import (
|
from deerflow.runtime import (
|
||||||
@@ -140,7 +141,14 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
|
|||||||
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
|
"""Merge whitelisted keys from ``body.context`` into both ``config['configurable']``
|
||||||
and ``config['context']`` so they are visible to legacy configurable readers and
|
and ``config['context']`` so they are visible to legacy configurable readers and
|
||||||
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool —
|
to LangGraph ``ToolRuntime.context`` consumers (e.g. the ``setup_agent`` tool —
|
||||||
see issue #2677)."""
|
see issue #2677).
|
||||||
|
|
||||||
|
``user_id`` is intentionally propagated into ``config['context']`` in addition to
|
||||||
|
the whitelisted keys, so non-web callers (e.g. IM channels) that supply identity in
|
||||||
|
``body.context`` keep it on ``ToolRuntime.context``. It is merged with
|
||||||
|
``setdefault`` so a server-authenticated id stamped by
|
||||||
|
:func:`inject_authenticated_user_context` always wins over the client-supplied one.
|
||||||
|
"""
|
||||||
if not context:
|
if not context:
|
||||||
return
|
return
|
||||||
configurable = config.setdefault("configurable", {})
|
configurable = config.setdefault("configurable", {})
|
||||||
@@ -151,6 +159,8 @@ def merge_run_context_overrides(config: dict[str, Any], context: Mapping[str, An
|
|||||||
configurable.setdefault(key, context[key])
|
configurable.setdefault(key, context[key])
|
||||||
if isinstance(runtime_context, dict):
|
if isinstance(runtime_context, dict):
|
||||||
runtime_context.setdefault(key, context[key])
|
runtime_context.setdefault(key, context[key])
|
||||||
|
if "user_id" in context and isinstance(runtime_context, dict):
|
||||||
|
runtime_context.setdefault("user_id", context["user_id"])
|
||||||
|
|
||||||
|
|
||||||
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
|
def inject_authenticated_user_context(config: dict[str, Any], request: Request) -> None:
|
||||||
@@ -166,6 +176,9 @@ def inject_authenticated_user_context(config: dict[str, Any], request: Request)
|
|||||||
if user_id is None:
|
if user_id is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if getattr(user, "system_role", None) == INTERNAL_SYSTEM_ROLE:
|
||||||
|
return
|
||||||
|
|
||||||
runtime_context = config.setdefault("context", {})
|
runtime_context = config.setdefault("context", {})
|
||||||
if isinstance(runtime_context, dict):
|
if isinstance(runtime_context, dict):
|
||||||
runtime_context["user_id"] = str(user_id)
|
runtime_context["user_id"] = str(user_id)
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import hashlib
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
import shutil
|
||||||
@@ -10,6 +11,8 @@ VIRTUAL_PATH_PREFIX = "/mnt/user-data"
|
|||||||
|
|
||||||
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
_SAFE_THREAD_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||||
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
_SAFE_USER_ID_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||||
|
_UNSAFE_USER_ID_CHAR_RE = re.compile(r"[^A-Za-z0-9_\-]")
|
||||||
|
_SAFE_USER_ID_DIGEST_HEX_LEN = 16
|
||||||
|
|
||||||
|
|
||||||
def _default_local_base_dir() -> Path:
|
def _default_local_base_dir() -> Path:
|
||||||
@@ -31,6 +34,23 @@ def _validate_user_id(user_id: str) -> str:
|
|||||||
return user_id
|
return user_id
|
||||||
|
|
||||||
|
|
||||||
|
def make_safe_user_id(raw: str) -> str:
|
||||||
|
"""Normalize an external identity into the user-id charset (``[A-Za-z0-9_-]``).
|
||||||
|
|
||||||
|
IM channel ids (Feishu/Slack/Telegram) may contain characters that
|
||||||
|
:func:`_validate_user_id` rejects. Already-safe ids pass through unchanged;
|
||||||
|
lossy ones get a short digest suffix so two distinct inputs never share a
|
||||||
|
storage bucket.
|
||||||
|
"""
|
||||||
|
if not raw:
|
||||||
|
raise ValueError("user_id must be a non-empty string.")
|
||||||
|
sanitized = _UNSAFE_USER_ID_CHAR_RE.sub("-", raw)
|
||||||
|
if sanitized == raw:
|
||||||
|
return raw
|
||||||
|
digest = hashlib.sha1(raw.encode("utf-8")).hexdigest()[:_SAFE_USER_ID_DIGEST_HEX_LEN]
|
||||||
|
return f"{sanitized}-{digest}"
|
||||||
|
|
||||||
|
|
||||||
def _join_host_path(base: str, *parts: str) -> str:
|
def _join_host_path(base: str, *parts: str) -> str:
|
||||||
"""Join host filesystem path segments while preserving native style.
|
"""Join host filesystem path segments while preserving native style.
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import Mapping
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import BaseTool, StructuredTool
|
from langchain_core.tools import BaseTool, StructuredTool
|
||||||
@@ -137,7 +138,15 @@ def _make_session_pool_tool(
|
|||||||
from langchain_mcp_adapters.interceptors import MCPToolCallRequest
|
from langchain_mcp_adapters.interceptors import MCPToolCallRequest
|
||||||
|
|
||||||
async def base_handler(request: MCPToolCallRequest) -> Any:
|
async def base_handler(request: MCPToolCallRequest) -> Any:
|
||||||
return await session.call_tool(request.name, request.args)
|
# Preserve interceptor-injected headers for stdio MCP calls by
|
||||||
|
# forwarding them through MCP call meta.
|
||||||
|
call_kwargs: dict[str, Any] = {}
|
||||||
|
if request.headers:
|
||||||
|
if isinstance(request.headers, Mapping):
|
||||||
|
call_kwargs["meta"] = {"headers": dict(request.headers)}
|
||||||
|
else:
|
||||||
|
logger.warning("Ignoring MCP interceptor headers with unsupported type: %s", type(request.headers).__name__)
|
||||||
|
return await session.call_tool(request.name, request.args, **call_kwargs)
|
||||||
|
|
||||||
handler = base_handler
|
handler = base_handler
|
||||||
for interceptor in reversed(tool_interceptors):
|
for interceptor in reversed(tool_interceptors):
|
||||||
|
|||||||
@@ -1787,6 +1787,51 @@ class TestChannelManager:
|
|||||||
_run(go())
|
_run(go())
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveRunParamsUserId:
|
||||||
|
"""Regression for PR #3294: channel identity must reach ``run_context``
|
||||||
|
while staying safe for user-scoped filesystem buckets.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _manager(self):
|
||||||
|
from app.channels.manager import ChannelManager
|
||||||
|
|
||||||
|
bus = MessageBus()
|
||||||
|
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||||
|
return ChannelManager(bus=bus, store=store)
|
||||||
|
|
||||||
|
def test_safe_user_id_is_passed_through(self):
|
||||||
|
manager = self._manager()
|
||||||
|
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
|
||||||
|
|
||||||
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
|
|
||||||
|
assert run_context["user_id"] == "123456"
|
||||||
|
assert run_context["channel_user_id"] == "123456"
|
||||||
|
|
||||||
|
def test_unsafe_user_id_is_normalized_but_raw_preserved(self):
|
||||||
|
from deerflow.config.paths import make_safe_user_id
|
||||||
|
|
||||||
|
manager = self._manager()
|
||||||
|
raw = "user@example.com"
|
||||||
|
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
|
||||||
|
|
||||||
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
|
|
||||||
|
assert run_context["user_id"] == make_safe_user_id(raw)
|
||||||
|
assert run_context["user_id"] != raw
|
||||||
|
assert run_context["channel_user_id"] == raw
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("raw_user_id", ["", None])
|
||||||
|
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id):
|
||||||
|
manager = self._manager()
|
||||||
|
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
|
||||||
|
|
||||||
|
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||||
|
|
||||||
|
assert "user_id" not in run_context
|
||||||
|
assert "channel_user_id" not in run_context
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# ChannelService tests
|
# ChannelService tests
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -431,6 +431,49 @@ def test_inject_authenticated_user_context_overrides_client_user_id():
|
|||||||
assert config["context"]["user_id"] == "auth-user-42"
|
assert config["context"]["user_id"] == "auth-user-42"
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_run_context_overrides_propagates_user_id():
|
||||||
|
"""Regression for PR #3294: ``user_id`` from ``body.context`` must land in
|
||||||
|
``config['context']`` so non-web callers (e.g. IM channels) keep their identity
|
||||||
|
on ``ToolRuntime.context``.
|
||||||
|
"""
|
||||||
|
from app.gateway.services import build_run_config, merge_run_context_overrides
|
||||||
|
|
||||||
|
config = build_run_config("thread-1", None, None)
|
||||||
|
merge_run_context_overrides(config, {"user_id": "channel-user-7"})
|
||||||
|
|
||||||
|
assert config["context"]["user_id"] == "channel-user-7"
|
||||||
|
|
||||||
|
|
||||||
|
def test_merge_run_context_overrides_does_not_clobber_existing_user_id():
|
||||||
|
"""``merge_run_context_overrides`` must not override an already-stamped
|
||||||
|
authenticated ``context.user_id`` with the client-supplied value.
|
||||||
|
"""
|
||||||
|
from app.gateway.services import build_run_config, merge_run_context_overrides
|
||||||
|
|
||||||
|
config = build_run_config("thread-1", {"context": {"user_id": "auth-user-42"}}, None)
|
||||||
|
merge_run_context_overrides(config, {"user_id": "spoofed-client"})
|
||||||
|
|
||||||
|
assert config["context"]["user_id"] == "auth-user-42"
|
||||||
|
|
||||||
|
|
||||||
|
def test_inject_authenticated_user_context_skips_internal_role():
|
||||||
|
"""Regression for PR #3294: internal system-role callers must not overwrite an
|
||||||
|
already-present ``context.user_id`` (e.g. a channel-supplied identity), so the
|
||||||
|
real end user keeps owning the per-user storage bucket.
|
||||||
|
"""
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from app.gateway.services import build_run_config, inject_authenticated_user_context
|
||||||
|
|
||||||
|
config = build_run_config("thread-1", None, None)
|
||||||
|
config["context"] = {"user_id": "channel-user-7"}
|
||||||
|
request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="internal-bot", system_role="internal")))
|
||||||
|
|
||||||
|
inject_authenticated_user_context(config, request)
|
||||||
|
|
||||||
|
assert config["context"]["user_id"] == "channel-user-7"
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -256,6 +256,136 @@ async def test_session_pool_tool_wrapping():
|
|||||||
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
|
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_pool_tool_forwards_interceptor_headers():
|
||||||
|
"""Regression for PR #3294: when an interceptor sets ``request.headers``, the
|
||||||
|
pooled stdio call must forward them via ``meta={"headers": ...}`` so downstream
|
||||||
|
MCP servers can read auth/context headers.
|
||||||
|
"""
|
||||||
|
from langchain_core.tools import StructuredTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from deerflow.mcp.tools import _make_session_pool_tool
|
||||||
|
|
||||||
|
class Args(BaseModel):
|
||||||
|
x: int = Field(..., description="x")
|
||||||
|
|
||||||
|
original_tool = StructuredTool(
|
||||||
|
name="srv_act",
|
||||||
|
description="test",
|
||||||
|
args_schema=Args,
|
||||||
|
coroutine=AsyncMock(),
|
||||||
|
response_format="content_and_artifact",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||||
|
mock_cm = MagicMock()
|
||||||
|
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
|
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
async def header_interceptor(request, handler):
|
||||||
|
return await handler(request.override(headers={"X-User-Id": "u-42"}))
|
||||||
|
|
||||||
|
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||||
|
wrapped = _make_session_pool_tool(
|
||||||
|
original_tool,
|
||||||
|
"srv",
|
||||||
|
{"transport": "stdio", "command": "x", "args": []},
|
||||||
|
tool_interceptors=[header_interceptor],
|
||||||
|
)
|
||||||
|
await wrapped.coroutine(runtime=None, x=1)
|
||||||
|
|
||||||
|
mock_session.call_tool.assert_awaited_once_with("act", {"x": 1}, meta={"headers": {"X-User-Id": "u-42"}})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_pool_tool_no_headers_omits_meta():
|
||||||
|
"""When no interceptor sets headers, the pooled call must not pass a ``meta``
|
||||||
|
kwarg (falls back to the plain two-argument ``call_tool``).
|
||||||
|
"""
|
||||||
|
from langchain_core.tools import StructuredTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from deerflow.mcp.tools import _make_session_pool_tool
|
||||||
|
|
||||||
|
class Args(BaseModel):
|
||||||
|
x: int = Field(..., description="x")
|
||||||
|
|
||||||
|
original_tool = StructuredTool(
|
||||||
|
name="srv_act",
|
||||||
|
description="test",
|
||||||
|
args_schema=Args,
|
||||||
|
coroutine=AsyncMock(),
|
||||||
|
response_format="content_and_artifact",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||||
|
mock_cm = MagicMock()
|
||||||
|
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
|
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
async def passthrough_interceptor(request, handler):
|
||||||
|
return await handler(request)
|
||||||
|
|
||||||
|
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||||
|
wrapped = _make_session_pool_tool(
|
||||||
|
original_tool,
|
||||||
|
"srv",
|
||||||
|
{"transport": "stdio", "command": "x", "args": []},
|
||||||
|
tool_interceptors=[passthrough_interceptor],
|
||||||
|
)
|
||||||
|
await wrapped.coroutine(runtime=None, x=1)
|
||||||
|
|
||||||
|
mock_session.call_tool.assert_awaited_once_with("act", {"x": 1})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_session_pool_tool_ignores_unsupported_header_type(caplog):
|
||||||
|
"""Defensive path: non-mapping truthy headers should be ignored safely."""
|
||||||
|
from langchain_core.tools import StructuredTool
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from deerflow.mcp.tools import _make_session_pool_tool
|
||||||
|
|
||||||
|
class Args(BaseModel):
|
||||||
|
x: int = Field(..., description="x")
|
||||||
|
|
||||||
|
class TruthyHeaders:
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
original_tool = StructuredTool(
|
||||||
|
name="srv_act",
|
||||||
|
description="test",
|
||||||
|
args_schema=Args,
|
||||||
|
coroutine=AsyncMock(),
|
||||||
|
response_format="content_and_artifact",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_session = AsyncMock()
|
||||||
|
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||||
|
mock_cm = MagicMock()
|
||||||
|
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||||
|
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
async def invalid_header_interceptor(request, handler):
|
||||||
|
return await handler(request.override(headers=TruthyHeaders()))
|
||||||
|
|
||||||
|
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||||
|
wrapped = _make_session_pool_tool(
|
||||||
|
original_tool,
|
||||||
|
"srv",
|
||||||
|
{"transport": "stdio", "command": "x", "args": []},
|
||||||
|
tool_interceptors=[invalid_header_interceptor],
|
||||||
|
)
|
||||||
|
await wrapped.coroutine(runtime=None, x=1)
|
||||||
|
|
||||||
|
mock_session.call_tool.assert_awaited_once_with("act", {"x": 1})
|
||||||
|
assert "unsupported type" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_session_pool_tool_extracts_thread_id():
|
async def test_session_pool_tool_extracts_thread_id():
|
||||||
"""Thread ID is extracted from runtime.config when not in context."""
|
"""Thread ID is extracted from runtime.config when not in context."""
|
||||||
|
|||||||
@@ -30,6 +30,41 @@ class TestValidateUserId:
|
|||||||
paths.user_dir("")
|
paths.user_dir("")
|
||||||
|
|
||||||
|
|
||||||
|
class TestMakeSafeUserId:
|
||||||
|
def test_already_safe_id_is_unchanged(self):
|
||||||
|
from deerflow.config.paths import make_safe_user_id
|
||||||
|
|
||||||
|
assert make_safe_user_id("ou_abc-123") == "ou_abc-123"
|
||||||
|
assert make_safe_user_id("123456") == "123456"
|
||||||
|
|
||||||
|
def test_unsafe_chars_are_sanitized_with_stable_suffix(self):
|
||||||
|
from deerflow.config.paths import make_safe_user_id
|
||||||
|
|
||||||
|
result = make_safe_user_id("user@example.com")
|
||||||
|
# Sanitized prefix plus a stable digest of the original.
|
||||||
|
assert result.startswith("user-example-com-")
|
||||||
|
assert len(result.rsplit("-", 1)[1]) == 16
|
||||||
|
assert make_safe_user_id("user@example.com") == result
|
||||||
|
|
||||||
|
def test_sanitized_id_passes_validation(self, paths: Paths):
|
||||||
|
from deerflow.config.paths import make_safe_user_id
|
||||||
|
|
||||||
|
safe = make_safe_user_id("用户/../etc")
|
||||||
|
# Must be usable as a filesystem-scoped bucket without raising.
|
||||||
|
assert paths.user_dir(safe) == paths.base_dir / "users" / safe
|
||||||
|
|
||||||
|
def test_distinct_unsafe_ids_do_not_collide(self):
|
||||||
|
from deerflow.config.paths import make_safe_user_id
|
||||||
|
|
||||||
|
assert make_safe_user_id("a.b") != make_safe_user_id("a/b")
|
||||||
|
|
||||||
|
def test_empty_id_rejected(self):
|
||||||
|
from deerflow.config.paths import make_safe_user_id
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="non-empty"):
|
||||||
|
make_safe_user_id("")
|
||||||
|
|
||||||
|
|
||||||
class TestUserDir:
|
class TestUserDir:
|
||||||
def test_user_dir(self, paths: Paths):
|
def test_user_dir(self, paths: Paths):
|
||||||
assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"
|
assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"
|
||||||
|
|||||||
Reference in New Issue
Block a user