mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-26 18:06:00 +00:00
Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| ab2c7d07a5 | |||
| edeaa84563 | |||
| 2d84ddb1ae |
@@ -1,4 +1,4 @@
|
||||
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
|
||||
"""Load MCP tools using langchain-mcp-adapters with stdio session pooling."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -173,8 +173,10 @@ def _make_session_pool_tool(
|
||||
async def get_mcp_tools() -> list[BaseTool]:
|
||||
"""Get all tools from enabled MCP servers.
|
||||
|
||||
Tools are wrapped with persistent-session logic so that consecutive
|
||||
calls within the same thread reuse the same MCP session.
|
||||
Tools using stdio transport are wrapped with persistent-session logic so
|
||||
consecutive calls within the same thread reuse the same MCP session.
|
||||
HTTP/SSE tools are returned unwrapped to avoid cross-task TaskGroup
|
||||
cleanup errors.
|
||||
|
||||
Returns:
|
||||
List of LangChain tools from all enabled MCP servers.
|
||||
@@ -251,6 +253,9 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
|
||||
|
||||
# Wrap each tool with persistent-session logic.
|
||||
# Only pool stdio sessions. HTTP/SSE transports use anyio TaskGroups
|
||||
# internally which cannot be closed from a different async task, so
|
||||
# pooling them causes RuntimeError on cleanup (see #3203).
|
||||
wrapped_tools: list[BaseTool] = []
|
||||
for tool in tools:
|
||||
tool_server: str | None = None
|
||||
@@ -260,9 +265,13 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
break
|
||||
|
||||
if tool_server is not None:
|
||||
transport = servers_config[tool_server].get("transport", "stdio")
|
||||
if transport == "stdio":
|
||||
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
|
||||
else:
|
||||
wrapped_tools.append(tool)
|
||||
else:
|
||||
wrapped_tools.append(tool)
|
||||
|
||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||
for tool in wrapped_tools:
|
||||
|
||||
@@ -407,3 +407,80 @@ def test_session_pool_tool_sync_wrapper_path_is_safe():
|
||||
wrapped.func(url="https://example.com")
|
||||
|
||||
mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_mcp_tools: HTTP transport should NOT be pooled
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_http_transport_tools_not_pooled():
|
||||
"""HTTP/SSE transport tools should NOT be wrapped with the session pool."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import get_mcp_tools
|
||||
|
||||
class Args(BaseModel):
|
||||
query: str = Field(..., description="query")
|
||||
|
||||
http_tool = StructuredTool(
|
||||
name="myserver_search",
|
||||
description="Search tool",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
stdio_tool = StructuredTool(
|
||||
name="playwright_navigate",
|
||||
description="Navigate browser",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
extensions_config = MagicMock()
|
||||
extensions_config.get_enabled_mcp_servers.return_value = {
|
||||
"myserver": MagicMock(type="http", url="http://localhost:8000/mcp", headers=None, command=None, args=[], env=None),
|
||||
"playwright": MagicMock(type="stdio", command="npx", args=["-y", "@anthropic/mcp-server-playwright"], env=None, url=None, headers=None),
|
||||
}
|
||||
extensions_config.model_extra = {}
|
||||
|
||||
servers_config = {
|
||||
"myserver": {"transport": "http", "url": "http://localhost:8000/mcp"},
|
||||
"playwright": {"transport": "stdio", "command": "npx", "args": ["-y", "@anthropic/mcp-server-playwright"]},
|
||||
}
|
||||
|
||||
with (
|
||||
patch("deerflow.mcp.tools.ExtensionsConfig.from_file", return_value=extensions_config),
|
||||
patch("deerflow.mcp.tools.build_servers_config", return_value=servers_config),
|
||||
patch("deerflow.mcp.tools.get_initial_oauth_headers", return_value={}),
|
||||
patch("deerflow.mcp.tools.build_oauth_tool_interceptor", return_value=None),
|
||||
patch("langchain_mcp_adapters.client.MultiServerMCPClient") as MockClient,
|
||||
patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm),
|
||||
):
|
||||
mock_client_instance = MockClient.return_value
|
||||
mock_client_instance.get_tools = AsyncMock(return_value=[http_tool, stdio_tool])
|
||||
|
||||
tools = await get_mcp_tools()
|
||||
|
||||
pool = get_session_pool()
|
||||
# Tool discovery is lazy: no pooled sessions are created until a wrapped tool is invoked.
|
||||
assert list(pool._entries.keys()) == []
|
||||
|
||||
# Verify the HTTP tool was NOT wrapped with the pool (it's the original tool).
|
||||
http_tools = [t for t in tools if t.name == "myserver_search"]
|
||||
assert len(http_tools) == 1
|
||||
assert http_tools[0].coroutine is http_tool.coroutine
|
||||
|
||||
# Verify the stdio tool WAS wrapped with the pool.
|
||||
stdio_tools = [t for t in tools if t.name == "playwright_navigate"]
|
||||
assert len(stdio_tools) == 1
|
||||
assert stdio_tools[0].coroutine is not stdio_tool.coroutine
|
||||
|
||||
Reference in New Issue
Block a user