fix(mcp): add auth interceptor with channel user_id and keep header propagation to mcp tools (#3294)

* 修复channel中的user_id传递到interceptor中的bug, mcp可通过header传递user_id到mcp工具

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(channel,mcp,gateway): normalize channel user_id and add regression tests

Normalize external channel user ids into filesystem-safe runtime context while preserving raw channel_user_id, and document gateway user_id propagation semantics. Add regression coverage for channel user_id context mapping, gateway user_id precedence/internal-role behavior, and MCP interceptor header forwarding via meta.headers.

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(auth,mcp): harden user id normalization and header handling

Increase sanitized user-id digest suffix to 16 hex chars, replace internal system role magic string with a shared constant, and harden MCP header forwarding with Mapping type checks. Add regression tests for empty channel user_id handling, unsupported header types, and updated digest length behavior.

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: zhongli <335302680@qq.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
zhongli-sz
2026-06-03 15:48:19 +08:00
committed by GitHub
parent 5dc2d6cbf5
commit 3ae82dc663
9 changed files with 309 additions and 4 deletions
+45
View File
@@ -1787,6 +1787,51 @@ class TestChannelManager:
_run(go())
class TestResolveRunParamsUserId:
"""Regression for PR #3294: channel identity must reach ``run_context``
while staying safe for user-scoped filesystem buckets.
"""
def _manager(self):
from app.channels.manager import ChannelManager
bus = MessageBus()
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
return ChannelManager(bus=bus, store=store)
def test_safe_user_id_is_passed_through(self):
manager = self._manager()
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == "123456"
assert run_context["channel_user_id"] == "123456"
def test_unsafe_user_id_is_normalized_but_raw_preserved(self):
from deerflow.config.paths import make_safe_user_id
manager = self._manager()
raw = "user@example.com"
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert run_context["user_id"] == make_safe_user_id(raw)
assert run_context["user_id"] != raw
assert run_context["channel_user_id"] == raw
@pytest.mark.parametrize("raw_user_id", ["", None])
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id):
manager = self._manager()
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
assert "user_id" not in run_context
assert "channel_user_id" not in run_context
# ---------------------------------------------------------------------------
# ChannelService tests
# ---------------------------------------------------------------------------
+43
View File
@@ -431,6 +431,49 @@ def test_inject_authenticated_user_context_overrides_client_user_id():
assert config["context"]["user_id"] == "auth-user-42"
def test_merge_run_context_overrides_propagates_user_id():
"""Regression for PR #3294: ``user_id`` from ``body.context`` must land in
``config['context']`` so non-web callers (e.g. IM channels) keep their identity
on ``ToolRuntime.context``.
"""
from app.gateway.services import build_run_config, merge_run_context_overrides
config = build_run_config("thread-1", None, None)
merge_run_context_overrides(config, {"user_id": "channel-user-7"})
assert config["context"]["user_id"] == "channel-user-7"
def test_merge_run_context_overrides_does_not_clobber_existing_user_id():
"""``merge_run_context_overrides`` must not override an already-stamped
authenticated ``context.user_id`` with the client-supplied value.
"""
from app.gateway.services import build_run_config, merge_run_context_overrides
config = build_run_config("thread-1", {"context": {"user_id": "auth-user-42"}}, None)
merge_run_context_overrides(config, {"user_id": "spoofed-client"})
assert config["context"]["user_id"] == "auth-user-42"
def test_inject_authenticated_user_context_skips_internal_role():
"""Regression for PR #3294: internal system-role callers must not overwrite an
already-present ``context.user_id`` (e.g. a channel-supplied identity), so the
real end user keeps owning the per-user storage bucket.
"""
from types import SimpleNamespace
from app.gateway.services import build_run_config, inject_authenticated_user_context
config = build_run_config("thread-1", None, None)
config["context"] = {"user_id": "channel-user-7"}
request = SimpleNamespace(state=SimpleNamespace(user=SimpleNamespace(id="internal-bot", system_role="internal")))
inject_authenticated_user_context(config, request)
assert config["context"]["user_id"] == "channel-user-7"
# ---------------------------------------------------------------------------
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
# ---------------------------------------------------------------------------
+130
View File
@@ -256,6 +256,136 @@ async def test_session_pool_tool_wrapping():
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
@pytest.mark.asyncio
async def test_session_pool_tool_forwards_interceptor_headers():
"""Regression for PR #3294: when an interceptor sets ``request.headers``, the
pooled stdio call must forward them via ``meta={"headers": ...}`` so downstream
MCP servers can read auth/context headers.
"""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="srv_act",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
async def header_interceptor(request, handler):
return await handler(request.override(headers={"X-User-Id": "u-42"}))
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(
original_tool,
"srv",
{"transport": "stdio", "command": "x", "args": []},
tool_interceptors=[header_interceptor],
)
await wrapped.coroutine(runtime=None, x=1)
mock_session.call_tool.assert_awaited_once_with("act", {"x": 1}, meta={"headers": {"X-User-Id": "u-42"}})
@pytest.mark.asyncio
async def test_session_pool_tool_no_headers_omits_meta():
"""When no interceptor sets headers, the pooled call must not pass a ``meta``
kwarg (falls back to the plain two-argument ``call_tool``).
"""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
original_tool = StructuredTool(
name="srv_act",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
async def passthrough_interceptor(request, handler):
return await handler(request)
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(
original_tool,
"srv",
{"transport": "stdio", "command": "x", "args": []},
tool_interceptors=[passthrough_interceptor],
)
await wrapped.coroutine(runtime=None, x=1)
mock_session.call_tool.assert_awaited_once_with("act", {"x": 1})
@pytest.mark.asyncio
async def test_session_pool_tool_ignores_unsupported_header_type(caplog):
"""Defensive path: non-mapping truthy headers should be ignored safely."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import _make_session_pool_tool
class Args(BaseModel):
x: int = Field(..., description="x")
class TruthyHeaders:
def __bool__(self) -> bool:
return True
original_tool = StructuredTool(
name="srv_act",
description="test",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
async def invalid_header_interceptor(request, handler):
return await handler(request.override(headers=TruthyHeaders()))
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
wrapped = _make_session_pool_tool(
original_tool,
"srv",
{"transport": "stdio", "command": "x", "args": []},
tool_interceptors=[invalid_header_interceptor],
)
await wrapped.coroutine(runtime=None, x=1)
mock_session.call_tool.assert_awaited_once_with("act", {"x": 1})
assert "unsupported type" in caplog.text
@pytest.mark.asyncio
async def test_session_pool_tool_extracts_thread_id():
"""Thread ID is extracted from runtime.config when not in context."""
@@ -30,6 +30,41 @@ class TestValidateUserId:
paths.user_dir("")
class TestMakeSafeUserId:
def test_already_safe_id_is_unchanged(self):
from deerflow.config.paths import make_safe_user_id
assert make_safe_user_id("ou_abc-123") == "ou_abc-123"
assert make_safe_user_id("123456") == "123456"
def test_unsafe_chars_are_sanitized_with_stable_suffix(self):
from deerflow.config.paths import make_safe_user_id
result = make_safe_user_id("user@example.com")
# Sanitized prefix plus a stable digest of the original.
assert result.startswith("user-example-com-")
assert len(result.rsplit("-", 1)[1]) == 16
assert make_safe_user_id("user@example.com") == result
def test_sanitized_id_passes_validation(self, paths: Paths):
from deerflow.config.paths import make_safe_user_id
safe = make_safe_user_id("用户/../etc")
# Must be usable as a filesystem-scoped bucket without raising.
assert paths.user_dir(safe) == paths.base_dir / "users" / safe
def test_distinct_unsafe_ids_do_not_collide(self):
from deerflow.config.paths import make_safe_user_id
assert make_safe_user_id("a.b") != make_safe_user_id("a/b")
def test_empty_id_rejected(self):
from deerflow.config.paths import make_safe_user_id
with pytest.raises(ValueError, match="non-empty"):
make_safe_user_id("")
class TestUserDir:
def test_user_dir(self, paths: Paths):
assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"