mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
1b88c38d80
MCP tools loaded via langchain-mcp-adapters created a new session on every call, causing stateful servers like Playwright to lose browser state (pages, forms) between consecutive tool invocations within the same thread. Add MCPSessionPool that maintains persistent sessions scoped by (server_name, thread_id). Tool calls within the same thread now reuse the same MCP session, preserving server-side state. Sessions are evicted in LRU order (max 256) and cleaned up on cache invalidation. Fixes #3054
179 lines
7.2 KiB
Python
179 lines
7.2 KiB
Python
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
from typing import Annotated, Any
|
||
|
||
from langchain_core.tools import BaseTool, InjectedToolArg, StructuredTool
|
||
|
||
from deerflow.config.extensions_config import ExtensionsConfig
|
||
from deerflow.mcp.client import build_servers_config
|
||
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||
from deerflow.mcp.session_pool import get_session_pool
|
||
from deerflow.reflection import resolve_variable
|
||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def _extract_thread_id(runtime: Any) -> str:
    """Extract the thread ID from the injected tool runtime.

    Lookup order:

    1. ``thread_id`` on ``runtime.context`` (read as a key when the context
       exposes ``get``, otherwise as an attribute).
    2. ``runtime.config["configurable"]["thread_id"]``.

    Args:
        runtime: The runtime object injected into the tool call, or ``None``
            when nothing was injected.

    Returns:
        The thread ID converted to ``str``, or ``"default"`` when no thread ID
        can be found.
    """

    def _read(container: Any, key: str) -> Any:
        # Mapping-like containers expose ``get``; anything else (for example an
        # object-style context) is read by attribute. Missing containers, keys,
        # and attributes all yield ``None`` instead of raising.
        if container is None:
            return None
        getter = getattr(container, "get", None)
        if callable(getter):
            return getter(key)
        return getattr(container, key, None)

    if runtime is None:
        return "default"

    tid = _read(getattr(runtime, "context", None), "thread_id")
    if tid is not None:
        return str(tid)

    # ``configurable`` may be absent or explicitly ``None``; ``_read`` tolerates both.
    configurable = _read(getattr(runtime, "config", None), "configurable")
    tid = _read(configurable, "thread_id")
    if tid is not None:
        return str(tid)

    return "default"
|
||
|
||
|
||
def _make_session_pool_tool(
    tool: BaseTool,
    server_name: str,
    connection: dict[str, Any],
) -> BaseTool:
    """Build a replacement for ``tool`` whose calls go through the session pool.

    Instead of opening a session per call, the returned tool asks the pool for
    the session keyed by ``(server_name, thread_id)`` and invokes the MCP tool
    on it. Stateful MCP servers (e.g. Playwright) therefore keep their state
    across tool calls made within the same thread.

    Args:
        tool: The discovered MCP tool; its name, description, argument schema
            and metadata are copied onto the replacement.
        server_name: Name of the MCP server that owns ``tool``.
        connection: Connection settings handed to the pool when it needs a
            session for this server.

    Returns:
        A ``StructuredTool`` with the same public name as ``tool``.
    """
    # The tool name may carry a "<server_name>_" prefix; the MCP server itself
    # is addressed with the bare tool name.
    name_prefix = f"{server_name}_"
    mcp_tool_name = tool.name[len(name_prefix) :] if tool.name.startswith(name_prefix) else tool.name

    session_pool = get_session_pool()

    async def call_with_persistent_session(
        runtime: Annotated[object | None, InjectedToolArg()] = None,
        **arguments: dict[str, Any],
    ) -> Any:
        # Resolve the thread first, then fetch the session bound to it.
        session = await session_pool.get_session(
            server_name,
            _extract_thread_id(runtime),
            connection,
        )
        raw_result = await session.call_tool(mcp_tool_name, arguments)

        # Imported at call time rather than at module import time.
        from langchain_mcp_adapters.tools import _convert_call_tool_result

        return _convert_call_tool_result(raw_result)

    return StructuredTool(
        name=tool.name,
        description=tool.description,
        args_schema=tool.args_schema,
        coroutine=call_with_persistent_session,
        response_format="content_and_artifact",
        metadata=tool.metadata,
    )
|
||
|
||
|
||
async def get_mcp_tools() -> list[BaseTool]:
    """Get all tools from enabled MCP servers.

    Tools are wrapped with persistent-session logic so that consecutive
    calls within the same thread reuse the same MCP session.

    Returns:
        List of LangChain tools from all enabled MCP servers. An empty list is
        returned when ``langchain-mcp-adapters`` is not installed, when no MCP
        server is enabled, or when loading fails (the failure is logged).
    """
    try:
        from langchain_mcp_adapters.client import MultiServerMCPClient
    except ImportError:
        logger.warning("langchain-mcp-adapters not installed. Install it to enable MCP tools: pip install langchain-mcp-adapters")
        return []

    # NOTE: We use ExtensionsConfig.from_file() instead of get_extensions_config()
    # to always read the latest configuration from disk. This ensures that changes
    # made through the Gateway API (which runs in a separate process) are immediately
    # reflected when initializing MCP tools.
    extensions_config = ExtensionsConfig.from_file()
    servers_config = build_servers_config(extensions_config)

    if not servers_config:
        logger.info("No enabled MCP servers configured")
        return []

    try:
        # Create the multi-server MCP client
        logger.info(f"Initializing MCP client with {len(servers_config)} server(s)")

        # Inject initial OAuth headers for server connections (tool discovery/session init)
        initial_oauth_headers = await get_initial_oauth_headers(extensions_config)
        for server_name, auth_header in initial_oauth_headers.items():
            if server_name not in servers_config:
                continue
            if servers_config[server_name].get("transport") in ("sse", "http"):
                # ``or {}`` also covers an explicit ``"headers": None`` entry,
                # which would otherwise make ``dict(None)`` raise and abort loading.
                existing_headers = dict(servers_config[server_name].get("headers") or {})
                existing_headers["Authorization"] = auth_header
                servers_config[server_name]["headers"] = existing_headers

        tool_interceptors: list[Any] = []
        oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
        if oauth_interceptor is not None:
            tool_interceptors.append(oauth_interceptor)

        # Load custom interceptors declared in extensions_config.json
        # Format: "mcpInterceptors": ["pkg.module:builder_func", ...]
        raw_interceptor_paths = (extensions_config.model_extra or {}).get("mcpInterceptors")
        if isinstance(raw_interceptor_paths, str):
            raw_interceptor_paths = [raw_interceptor_paths]
        elif not isinstance(raw_interceptor_paths, list):
            if raw_interceptor_paths is not None:
                logger.warning(f"mcpInterceptors must be a list of strings, got {type(raw_interceptor_paths).__name__}; skipping")
            raw_interceptor_paths = []
        for interceptor_path in raw_interceptor_paths:
            # One broken interceptor must not prevent the others (or the tools)
            # from loading, so failures are logged and skipped.
            try:
                builder = resolve_variable(interceptor_path)
                interceptor = builder()
                if callable(interceptor):
                    tool_interceptors.append(interceptor)
                    logger.info(f"Loaded MCP interceptor: {interceptor_path}")
                elif interceptor is not None:
                    logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
            except Exception as e:
                logger.warning(
                    f"Failed to load MCP interceptor {interceptor_path}: {e}",
                    exc_info=True,
                )

        client = MultiServerMCPClient(
            servers_config,
            tool_interceptors=tool_interceptors,
            tool_name_prefix=True,
        )

        # Get all tools from all servers (discovers tool definitions via
        # temporary sessions – the persistent-session wrapping is applied below).
        tools = await client.get_tools()
        logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")

        # Wrap each tool with persistent-session logic.
        # Server names are tried longest-first so that, when one server name is
        # a prefix of another (e.g. "foo" and "foo_bar"), a tool named
        # "foo_bar_x" is attributed to "foo_bar" rather than to whichever name
        # happens to come first in the config.
        # NOTE(review): the wrapped coroutine calls ``session.call_tool``
        # directly, so ``tool_interceptors`` passed to the client above do not
        # appear to run for wrapped tools at call time — confirm this is intended.
        server_names_longest_first = sorted(servers_config, key=len, reverse=True)
        wrapped_tools: list[BaseTool] = []
        for tool in tools:
            tool_server: str | None = next(
                (name for name in server_names_longest_first if tool.name.startswith(f"{name}_")),
                None,
            )

            if tool_server is not None:
                wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server]))
            else:
                # No owning server could be determined; keep the tool as discovered.
                wrapped_tools.append(tool)

        # Patch tools to support sync invocation, as deerflow client streams synchronously
        for tool in wrapped_tools:
            if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
                tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)

        return wrapped_tools

    except Exception as e:
        logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
        return []
|