fix(mcp): skip session pooling for HTTP/SSE transports to avoid anyio RuntimeError (#3203)

HTTP/SSE transports use anyio.TaskGroup internally for streamable
  connections. These task groups have cancel scopes bound to the async task
  that created them, so closing a pooled session from a different task
  raises RuntimeError. Restrict session pooling to stdio transports only.
This commit is contained in:
Willem Jiang
2026-05-26 09:15:21 +08:00
parent f9b7071304
commit 2d84ddb1ae
2 changed files with 81 additions and 1 deletions
@@ -251,6 +251,9 @@ async def get_mcp_tools() -> list[BaseTool]:
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers") logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Wrap each tool with persistent-session logic. # Wrap each tool with persistent-session logic.
# Only pool stdio sessions. HTTP/SSE transports use anyio TaskGroups
# internally which cannot be closed from a different async task, so
# pooling them causes RuntimeError on cleanup (see #3203).
wrapped_tools: list[BaseTool] = [] wrapped_tools: list[BaseTool] = []
for tool in tools: for tool in tools:
tool_server: str | None = None tool_server: str | None = None
@@ -260,7 +263,11 @@ async def get_mcp_tools() -> list[BaseTool]:
break break
if tool_server is not None: if tool_server is not None:
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors)) transport = servers_config[tool_server].get("transport", "stdio")
if transport == "stdio":
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
else:
wrapped_tools.append(tool)
else: else:
wrapped_tools.append(tool) wrapped_tools.append(tool)
+73
View File
@@ -407,3 +407,76 @@ def test_session_pool_tool_sync_wrapper_path_is_safe():
wrapped.func(url="https://example.com") wrapped.func(url="https://example.com")
mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"}) mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
# ---------------------------------------------------------------------------
# get_mcp_tools: HTTP transport should NOT be pooled
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_http_transport_tools_not_pooled():
"""HTTP/SSE transport tools should NOT be wrapped with the session pool."""
from langchain_core.tools import StructuredTool
from pydantic import BaseModel, Field
from deerflow.mcp.tools import get_mcp_tools
class Args(BaseModel):
query: str = Field(..., description="query")
http_tool = StructuredTool(
name="myserver_search",
description="Search tool",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
stdio_tool = StructuredTool(
name="playwright_navigate",
description="Navigate browser",
args_schema=Args,
coroutine=AsyncMock(),
response_format="content_and_artifact",
)
mock_session = AsyncMock()
mock_cm = MagicMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
mock_cm.__aexit__ = AsyncMock(return_value=False)
extensions_config = MagicMock()
extensions_config.get_enabled_mcp_servers.return_value = {
"myserver": MagicMock(type="http", url="http://localhost:8000/mcp", headers=None, command=None, args=[], env=None),
"playwright": MagicMock(type="stdio", command="npx", args=["-y", "@anthropic/mcp-server-playwright"], env=None, url=None, headers=None),
}
extensions_config.model_extra = {}
servers_config = {
"myserver": {"transport": "http", "url": "http://localhost:8000/mcp"},
"playwright": {"transport": "stdio", "command": "npx", "args": ["-y", "@anthropic/mcp-server-playwright"]},
}
with (
patch("deerflow.mcp.tools.ExtensionsConfig.from_file", return_value=extensions_config),
patch("deerflow.mcp.tools.build_servers_config", return_value=servers_config),
patch("deerflow.mcp.tools.get_initial_oauth_headers", return_value={}),
patch("deerflow.mcp.tools.build_oauth_tool_interceptor", return_value=None),
patch("langchain_mcp_adapters.client.MultiServerMCPClient") as MockClient,
patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm),
):
mock_client_instance = MockClient.return_value
mock_client_instance.get_tools = AsyncMock(return_value=[http_tool, stdio_tool])
tools = await get_mcp_tools()
pool = get_session_pool()
# Only the stdio (playwright) tool should have a pool entry; HTTP should not.
pool_keys = list(pool._entries.keys())
assert ("playwright", "default") not in pool_keys # Not called yet, no entry
# Verify the HTTP tool was NOT wrapped with the pool (it's the original tool).
http_tools = [t for t in tools if t.name == "myserver_search"]
assert len(http_tools) == 1
# The HTTP tool's coroutine should be the original, not the pool wrapper.
assert http_tools[0].coroutine is http_tool.coroutine