mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
fix(mcp): add auth interceptor with channel user_id and keep header propagation to mcp tools (#3294)
* 修复channel中的user_id传递到interceptor中的bug, mcp可通过header传递user_id到mcp工具 Co-authored-by: Cursor <cursoragent@cursor.com> * fix(channel,mcp,gateway): normalize channel user_id and add regression tests Normalize external channel user ids into filesystem-safe runtime context while preserving raw channel_user_id, and document gateway user_id propagation semantics. Add regression coverage for channel user_id context mapping, gateway user_id precedence/internal-role behavior, and MCP interceptor header forwarding via meta.headers. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(auth,mcp): harden user id normalization and header handling Increase sanitized user-id digest suffix to 16 hex chars, replace internal system role magic string with a shared constant, and harden MCP header forwarding with Mapping type checks. Add regression tests for empty channel user_id handling, unsupported header types, and updated digest length behavior. Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: zhongli <335302680@qq.com> Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -256,6 +256,136 @@ async def test_session_pool_tool_wrapping():
|
||||
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_forwards_interceptor_headers():
|
||||
"""Regression for PR #3294: when an interceptor sets ``request.headers``, the
|
||||
pooled stdio call must forward them via ``meta={"headers": ...}`` so downstream
|
||||
MCP servers can read auth/context headers.
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="srv_act",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def header_interceptor(request, handler):
|
||||
return await handler(request.override(headers={"X-User-Id": "u-42"}))
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(
|
||||
original_tool,
|
||||
"srv",
|
||||
{"transport": "stdio", "command": "x", "args": []},
|
||||
tool_interceptors=[header_interceptor],
|
||||
)
|
||||
await wrapped.coroutine(runtime=None, x=1)
|
||||
|
||||
mock_session.call_tool.assert_awaited_once_with("act", {"x": 1}, meta={"headers": {"X-User-Id": "u-42"}})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_no_headers_omits_meta():
|
||||
"""When no interceptor sets headers, the pooled call must not pass a ``meta``
|
||||
kwarg (falls back to the plain two-argument ``call_tool``).
|
||||
"""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="srv_act",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def passthrough_interceptor(request, handler):
|
||||
return await handler(request)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(
|
||||
original_tool,
|
||||
"srv",
|
||||
{"transport": "stdio", "command": "x", "args": []},
|
||||
tool_interceptors=[passthrough_interceptor],
|
||||
)
|
||||
await wrapped.coroutine(runtime=None, x=1)
|
||||
|
||||
mock_session.call_tool.assert_awaited_once_with("act", {"x": 1})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_ignores_unsupported_header_type(caplog):
|
||||
"""Defensive path: non-mapping truthy headers should be ignored safely."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
class TruthyHeaders:
|
||||
def __bool__(self) -> bool:
|
||||
return True
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="srv_act",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
async def invalid_header_interceptor(request, handler):
|
||||
return await handler(request.override(headers=TruthyHeaders()))
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(
|
||||
original_tool,
|
||||
"srv",
|
||||
{"transport": "stdio", "command": "x", "args": []},
|
||||
tool_interceptors=[invalid_header_interceptor],
|
||||
)
|
||||
await wrapped.coroutine(runtime=None, x=1)
|
||||
|
||||
mock_session.call_tool.assert_awaited_once_with("act", {"x": 1})
|
||||
assert "unsupported type" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_extracts_thread_id():
|
||||
"""Thread ID is extracted from runtime.config when not in context."""
|
||||
|
||||
Reference in New Issue
Block a user