fix(mcp): persist MCP sessions across tool calls for stateful servers

MCP tools loaded via langchain-mcp-adapters created a new session on
  every call, causing stateful servers like Playwright to lose browser
  state (pages, forms) between consecutive tool invocations within the
  same thread.

  Add MCPSessionPool that maintains persistent sessions scoped by
  (server_name, thread_id). Tool calls within the same thread now reuse
  the same MCP session, preserving server-side state. Sessions are evicted
  in LRU order (max 256) and cleaned up on cache invalidation.

  Fixes #3054
This commit is contained in:
Willem Jiang
2026-05-20 11:17:58 +08:00
parent c810e9f809
commit 1b88c38d80
4 changed files with 590 additions and 8 deletions
+92 -8
View File
@@ -1,21 +1,83 @@
"""Load MCP tools using langchain-mcp-adapters."""
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
from __future__ import annotations
import logging
from typing import Annotated, Any
from langchain_core.tools import BaseTool
from langchain_core.tools import BaseTool, InjectedToolArg, StructuredTool
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
from deerflow.mcp.session_pool import get_session_pool
from deerflow.reflection import resolve_variable
from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__)
def _extract_thread_id(runtime: Any) -> str:
"""Extract thread_id from the injected tool runtime."""
if runtime is not None:
ctx = getattr(runtime, "context", None) or {}
tid = ctx.get("thread_id")
if tid is not None:
return str(tid)
config = getattr(runtime, "config", None) or {}
tid = config.get("configurable", {}).get("thread_id")
if tid is not None:
return str(tid)
return "default"
def _make_session_pool_tool(
tool: BaseTool,
server_name: str,
connection: dict[str, Any],
) -> BaseTool:
"""Wrap an MCP tool so it reuses a persistent session from the pool.
Replaces the per-call session creation with pool-managed sessions scoped
by ``(server_name, thread_id)``. This ensures stateful MCP servers (e.g.
Playwright) keep their state across tool calls within the same thread.
"""
# Strip the server-name prefix to recover the original MCP tool name.
original_name = tool.name
prefix = f"{server_name}_"
if original_name.startswith(prefix):
original_name = original_name[len(prefix) :]
pool = get_session_pool()
async def call_with_persistent_session(
runtime: Annotated[object | None, InjectedToolArg()] = None,
**arguments: dict[str, Any],
) -> Any:
thread_id = _extract_thread_id(runtime)
session = await pool.get_session(server_name, thread_id, connection)
call_tool_result = await session.call_tool(original_name, arguments)
from langchain_mcp_adapters.tools import _convert_call_tool_result
return _convert_call_tool_result(call_tool_result)
return StructuredTool(
name=tool.name,
description=tool.description,
args_schema=tool.args_schema,
coroutine=call_with_persistent_session,
response_format="content_and_artifact",
metadata=tool.metadata,
)
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
Tools are wrapped with persistent-session logic so that consecutive
calls within the same thread reuse the same MCP session.
Returns:
List of LangChain tools from all enabled MCP servers.
"""
@@ -50,7 +112,7 @@ async def get_mcp_tools() -> list[BaseTool]:
existing_headers["Authorization"] = auth_header
servers_config[server_name]["headers"] = existing_headers
tool_interceptors = []
tool_interceptors: list[Any] = []
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
if oauth_interceptor is not None:
tool_interceptors.append(oauth_interceptor)
@@ -74,20 +136,42 @@ async def get_mcp_tools() -> list[BaseTool]:
elif interceptor is not None:
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
except Exception as e:
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
logger.warning(
f"Failed to load MCP interceptor {interceptor_path}: {e}",
exc_info=True,
)
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
client = MultiServerMCPClient(
servers_config,
tool_interceptors=tool_interceptors,
tool_name_prefix=True,
)
# Get all tools from all servers
# Get all tools from all servers (discovers tool definitions via
# temporary sessions the persistent-session wrapping is applied below).
tools = await client.get_tools()
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Patch tools to support sync invocation, as deerflow client streams synchronously
# Wrap each tool with persistent-session logic.
wrapped_tools: list[BaseTool] = []
for tool in tools:
tool_server: str | None = None
for name in servers_config:
if tool.name.startswith(f"{name}_"):
tool_server = name
break
if tool_server is not None:
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server]))
else:
wrapped_tools.append(tool)
# Patch tools to support sync invocation, as deerflow client streams synchronously
for tool in wrapped_tools:
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
return tools
return wrapped_tools
except Exception as e:
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)