fix(mcp): persist MCP sessions across tool calls for stateful servers

MCP tools loaded via langchain-mcp-adapters created a new session on
  every call, causing stateful servers like Playwright to lose browser
  state (pages, forms) between consecutive tool invocations within the
  same thread.

  Add MCPSessionPool that maintains persistent sessions scoped by
  (server_name, thread_id). Tool calls within the same thread now reuse
  the same MCP session, preserving server-side state. Sessions are evicted
  in LRU order (max 256) and cleaned up on cache invalidation.

  Fixes #3054
This commit is contained in:
Willem Jiang
2026-05-20 11:17:58 +08:00
parent c810e9f809
commit 1b88c38d80
4 changed files with 590 additions and 8 deletions
@@ -134,9 +134,36 @@ def reset_mcp_tools_cache() -> None:
"""Reset the MCP tools cache.
This is useful for testing or when you want to reload MCP tools.
Also closes all persistent MCP sessions so they are recreated on
the next tool load.
"""
global _mcp_tools_cache, _cache_initialized, _config_mtime
# Clear the three module-level cache globals so the next load starts from scratch.
_mcp_tools_cache = None
_cache_initialized = False
_config_mtime = None
# Close persistent sessions; the pool opens new ones on demand the next
# time a wrapped tool is called, so an updated connection config takes effect.
try:
from deerflow.mcp.session_pool import get_session_pool
pool = get_session_pool()
try:
# NOTE(review): asyncio.get_event_loop() is deprecated when no loop is
# running; the RuntimeError handler below covers interpreters where it raises.
loop = asyncio.get_event_loop()
if loop.is_running():
# A loop is already running in this thread, so close_all() is driven by a
# fresh loop (asyncio.run) on a worker thread. future.result() blocks the
# calling thread until that finishes.
# NOTE(review): the sessions are then closed from a different event loop
# than the one that opened them - confirm the transports tolerate that.
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, pool.close_all())
future.result()
else:
# A loop exists but is idle: run the close on it directly.
loop.run_until_complete(pool.close_all())
except RuntimeError:
# Reached when get_event_loop() raised, and also when a RuntimeError
# escaped from either close attempt above; retry on a brand-new loop.
asyncio.run(pool.close_all())
except Exception:
# Best effort: a failed close is logged at debug level and otherwise ignored.
logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
# Drop the pool singleton so the next get_session_pool() call builds a new pool.
from deerflow.mcp.session_pool import reset_session_pool
reset_session_pool()
logger.info("MCP tools cache reset")
@@ -0,0 +1,141 @@
"""Persistent MCP session pool for stateful tool calls.
When MCP tools are loaded via langchain-mcp-adapters with ``session=None``,
each tool call creates a new MCP session. For stateful servers like Playwright,
this means browser state (opened pages, filled forms) is lost between calls.
This module provides a session pool that maintains persistent MCP sessions,
scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id —
so that consecutive tool calls share the same session and server-side state.
Sessions are evicted in LRU order when the pool reaches capacity.
"""
from __future__ import annotations
import asyncio
import logging
import threading
from collections import OrderedDict
from typing import Any
from mcp import ClientSession
logger = logging.getLogger(__name__)
class MCPSessionPool:
    """Manages persistent MCP sessions scoped by ``(server_name, scope_key)``.

    Sessions live in an ``OrderedDict`` whose order doubles as LRU order: a
    cache hit moves the key to the end, and once ``MAX_SESSIONS`` entries
    exist the keys at the front are closed to make room. Every entry also
    records the event loop it was opened on; a lookup from a different loop
    closes the old session and opens a fresh one on the current loop.
    """

    MAX_SESSIONS = 256

    def __init__(self) -> None:
        # key -> (session, event loop the session was opened on)
        self._entries: OrderedDict[
            tuple[str, str],
            tuple[ClientSession, asyncio.AbstractEventLoop],
        ] = OrderedDict()
        # key -> async context manager that owns the session; exiting it closes the session
        self._context_managers: dict[tuple[str, str], Any] = {}
        # NOTE(review): one asyncio.Lock guards the pool although sessions may be
        # requested from several event loops - confirm it is never contended across loops.
        self._lock = asyncio.Lock()

    async def get_session(
        self,
        server_name: str,
        scope_key: str,
        connection: dict[str, Any],
    ) -> ClientSession:
        """Get or create a persistent MCP session.

        If an existing session was created in a different event loop (e.g.
        the sync-wrapper path), it is closed and replaced with a fresh one
        in the current loop. The pool lock is held for the whole call,
        including while a new session is being opened.

        Args:
            server_name: MCP server name.
            scope_key: Isolation key (typically thread_id).
            connection: Connection configuration for ``create_session``.

        Returns:
            An initialized ``ClientSession``.

        Raises:
            BaseException: Whatever ``create_session`` / ``initialize`` raised.
                A session that failed to initialize is closed before the
                error propagates and is never stored in the pool.
        """
        key = (server_name, scope_key)
        current_loop = asyncio.get_running_loop()

        async with self._lock:
            entry = self._entries.get(key)
            if entry is not None:
                session, loop = entry
                if loop is current_loop:
                    self._entries.move_to_end(key)
                    return session
                # Session belongs to a different event loop; close it.
                await self._close_session(key)

            # Evict oldest entries when at capacity. The emptiness guard keeps a
            # non-positive MAX_SESSIONS from calling next() on an empty dict.
            while self._entries and len(self._entries) >= self.MAX_SESSIONS:
                oldest_key = next(iter(self._entries))
                await self._close_session(oldest_key)

            # Imported at call time so the dependency is only needed when a
            # session is actually opened.
            from langchain_mcp_adapters.sessions import create_session

            cm = create_session(connection)
            session = await cm.__aenter__()
            try:
                await session.initialize()
            except BaseException as exc:
                # The context manager is already entered; exit it here or the
                # half-open session would leak, since it is not tracked yet.
                try:
                    await cm.__aexit__(type(exc), exc, exc.__traceback__)
                except Exception:
                    logger.warning(
                        "Error closing MCP session %s after failed initialization",
                        key,
                        exc_info=True,
                    )
                raise

            self._entries[key] = (session, current_loop)
            self._context_managers[key] = cm
            logger.info("Created persistent MCP session for %s/%s", server_name, scope_key)
            return session

    # ------------------------------------------------------------------
    # Cleanup helpers
    # ------------------------------------------------------------------

    async def _close_session(self, key: tuple[str, str]) -> None:
        """Forget ``key`` and exit its context manager; callers hold ``self._lock``.

        Bookkeeping is removed before the exit is attempted, so the key is
        gone from the pool even when closing fails. Errors raised while
        closing are logged as warnings and not re-raised.
        """
        cm = self._context_managers.pop(key, None)
        self._entries.pop(key, None)
        if cm is not None:
            try:
                await cm.__aexit__(None, None, None)
            except Exception:
                logger.warning("Error closing MCP session %s", key, exc_info=True)

    async def close_scope(self, scope_key: str) -> None:
        """Close all sessions for a given scope (e.g. thread_id)."""
        async with self._lock:
            # Snapshot the keys first: _close_session mutates the dict.
            keys_to_close = [k for k in self._entries if k[1] == scope_key]
            for key in keys_to_close:
                await self._close_session(key)

    async def close_server(self, server_name: str) -> None:
        """Close all sessions for a given server."""
        async with self._lock:
            keys_to_close = [k for k in self._entries if k[0] == server_name]
            for key in keys_to_close:
                await self._close_session(key)

    async def close_all(self) -> None:
        """Close every managed session."""
        async with self._lock:
            for key in list(self._context_managers):
                await self._close_session(key)
# ------------------------------------------------------------------
# Module-level singleton
# ------------------------------------------------------------------
# Lazily created process-wide pool; None until first requested or after a reset.
_pool: MCPSessionPool | None = None
# Serializes construction of the singleton across threads.
_pool_lock = threading.Lock()


def get_session_pool() -> MCPSessionPool:
    """Return the shared session pool, creating it on first use."""
    global _pool
    if _pool is not None:
        return _pool
    with _pool_lock:
        # Re-check under the lock: another thread may have built the pool
        # between the unlocked test above and acquiring the lock.
        if _pool is None:
            _pool = MCPSessionPool()
    return _pool
def reset_session_pool() -> None:
    """Drop the global session-pool singleton.

    The next ``get_session_pool()`` call builds a fresh, empty pool. Sessions
    held by the discarded pool are *not* closed here; a caller that needs
    them shut down must ``await pool.close_all()`` before resetting.
    """
    global _pool
    # Take the same lock get_session_pool() uses so a reset cannot interleave
    # with a concurrent lazy construction of the singleton.
    with _pool_lock:
        _pool = None
+92 -8
View File
@@ -1,21 +1,83 @@
"""Load MCP tools using langchain-mcp-adapters."""
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
from __future__ import annotations
import logging
from typing import Annotated, Any
from langchain_core.tools import BaseTool
from langchain_core.tools import BaseTool, InjectedToolArg, StructuredTool
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
from deerflow.mcp.session_pool import get_session_pool
from deerflow.reflection import resolve_variable
from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__)
def _extract_thread_id(runtime: Any) -> str:
"""Extract thread_id from the injected tool runtime."""
if runtime is not None:
ctx = getattr(runtime, "context", None) or {}
tid = ctx.get("thread_id")
if tid is not None:
return str(tid)
config = getattr(runtime, "config", None) or {}
tid = config.get("configurable", {}).get("thread_id")
if tid is not None:
return str(tid)
return "default"
def _make_session_pool_tool(
    tool: BaseTool,
    server_name: str,
    connection: dict[str, Any],
) -> BaseTool:
    """Wrap an MCP tool so it reuses a persistent session from the pool.

    Replaces the per-call session creation with pool-managed sessions scoped
    by ``(server_name, thread_id)``. This ensures stateful MCP servers (e.g.
    Playwright) keep their state across tool calls within the same thread.

    Args:
        tool: The tool to wrap; its name, description, args schema and
            metadata are copied onto the wrapper.
        server_name: Name of the MCP server the tool belongs to.
        connection: Connection configuration handed to the session pool.

    Returns:
        A ``StructuredTool`` whose coroutine calls the MCP tool through a
        pooled session.
    """
    # Strip the server-name prefix to recover the original MCP tool name.
    original_name = tool.name
    prefix = f"{server_name}_"
    if original_name.startswith(prefix):
        original_name = original_name[len(prefix) :]

    async def call_with_persistent_session(
        runtime: Annotated[object | None, InjectedToolArg()] = None,
        **arguments: Any,
    ) -> Any:
        # Resolve the pool on every call instead of capturing it at wrap time:
        # after reset_session_pool() a captured reference would point at a
        # discarded pool whose sessions nothing closes any more.
        pool = get_session_pool()
        thread_id = _extract_thread_id(runtime)
        session = await pool.get_session(server_name, thread_id, connection)
        # NOTE(review): this goes straight to session.call_tool, so interceptors
        # configured on the MCP client are presumably not applied here - confirm.
        call_tool_result = await session.call_tool(original_name, arguments)

        # NOTE(review): underscore-prefixed helper of langchain-mcp-adapters;
        # verify it still exists when upgrading that package.
        from langchain_mcp_adapters.tools import _convert_call_tool_result

        return _convert_call_tool_result(call_tool_result)

    return StructuredTool(
        name=tool.name,
        description=tool.description,
        args_schema=tool.args_schema,
        coroutine=call_with_persistent_session,
        response_format="content_and_artifact",
        metadata=tool.metadata,
    )
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
Tools are wrapped with persistent-session logic so that consecutive
calls within the same thread reuse the same MCP session.
Returns:
List of LangChain tools from all enabled MCP servers.
"""
@@ -50,7 +112,7 @@ async def get_mcp_tools() -> list[BaseTool]:
existing_headers["Authorization"] = auth_header
servers_config[server_name]["headers"] = existing_headers
tool_interceptors = []
tool_interceptors: list[Any] = []
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
if oauth_interceptor is not None:
tool_interceptors.append(oauth_interceptor)
@@ -74,20 +136,42 @@ async def get_mcp_tools() -> list[BaseTool]:
elif interceptor is not None:
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
except Exception as e:
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
logger.warning(
f"Failed to load MCP interceptor {interceptor_path}: {e}",
exc_info=True,
)
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
client = MultiServerMCPClient(
servers_config,
tool_interceptors=tool_interceptors,
tool_name_prefix=True,
)
# Get all tools from all servers
# Get all tools from all servers (discovers tool definitions via
# temporary sessions; the persistent-session wrapping is applied below).
tools = await client.get_tools()
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Patch tools to support sync invocation, as deerflow client streams synchronously
# Wrap each tool with persistent-session logic.
wrapped_tools: list[BaseTool] = []
for tool in tools:
tool_server: str | None = None
for name in servers_config:
if tool.name.startswith(f"{name}_"):
tool_server = name
break
if tool_server is not None:
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server]))
else:
wrapped_tools.append(tool)
# Patch tools to support sync invocation, as deerflow client streams synchronously
for tool in wrapped_tools:
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
return tools
return wrapped_tools
except Exception as e:
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
+330
View File
@@ -0,0 +1,330 @@
"""Tests for the MCP persistent-session pool."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool
@pytest.fixture(autouse=True)
def _reset_pool():
    """Give every test a fresh pool singleton and drop it again afterwards."""
    reset_session_pool()
    try:
        yield
    finally:
        reset_session_pool()
# ---------------------------------------------------------------------------
# MCPSessionPool unit tests
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_get_session_creates_new():
    """A cache miss opens a session and initializes it exactly once."""
    fake_session = AsyncMock()
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=fake_session)
    session_cm.__aexit__ = AsyncMock(return_value=False)
    connection = {"transport": "stdio", "command": "x", "args": []}
    pool = MCPSessionPool()

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=session_cm):
        obtained = await pool.get_session("server", "thread-1", connection)

    assert obtained is fake_session
    fake_session.initialize.assert_awaited_once()
@pytest.mark.asyncio
async def test_get_session_reuses_existing():
    """Asking twice for the same key hands back the cached session."""
    fake_session = AsyncMock()
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=fake_session)
    session_cm.__aexit__ = AsyncMock(return_value=False)
    connection = {"transport": "stdio", "command": "x", "args": []}
    pool = MCPSessionPool()

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=session_cm):
        first = await pool.get_session("server", "thread-1", connection)
        second = await pool.get_session("server", "thread-1", connection)

    assert first is second
    # The context manager was entered once, i.e. only one session was opened.
    assert session_cm.__aenter__.await_count == 1
@pytest.mark.asyncio
async def test_different_scope_creates_different_session():
    """Two scope keys on the same server must not share a session."""
    expected = [AsyncMock(), AsyncMock()]
    handed_out = 0

    class CmFactory:
        """Context manager that yields the prepared sessions in order."""

        def __init__(self):
            self.enter_count = 0

        async def __aenter__(self):
            nonlocal handed_out
            session = expected[handed_out]
            handed_out += 1
            self.enter_count += 1
            return session

        async def __aexit__(self, *args):
            return False

    def new_cm(*args, **kwargs):
        return CmFactory()

    connection = {"transport": "stdio", "command": "x", "args": []}
    pool = MCPSessionPool()

    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=new_cm):
        first = await pool.get_session("server", "thread-1", connection)
        second = await pool.get_session("server", "thread-2", connection)

    assert first is not second
    assert first is expected[0]
    assert second is expected[1]
@pytest.mark.asyncio
async def test_lru_eviction():
    """A full pool closes its least recently used session to admit a new one."""

    class CmFactory:
        """Context manager that records whether it has been exited."""

        def __init__(self):
            self.closed = False

        async def __aenter__(self):
            return AsyncMock()

        async def __aexit__(self, *args):
            self.closed = True
            return False

    opened: list[CmFactory] = []

    def make_cm(*args, **kwargs):
        opened.append(CmFactory())
        return opened[-1]

    connection = {"transport": "stdio", "command": "x", "args": []}
    pool = MCPSessionPool()
    pool.MAX_SESSIONS = 2

    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
        for scope in ("t1", "t2"):
            await pool.get_session("s", scope, connection)
        # Pool is full (2). Adding t3 should evict t1.
        await pool.get_session("s", "t3", connection)

    first_cm, second_cm, third_cm = opened[0], opened[1], opened[2]
    assert first_cm.closed is True
    assert second_cm.closed is False
    assert third_cm.closed is False
@pytest.mark.asyncio
async def test_close_scope():
    """close_scope only tears down the sessions of the requested scope key."""

    class CmFactory:
        """Context manager that records whether it has been exited."""

        def __init__(self):
            self.closed = False

        async def __aenter__(self):
            return AsyncMock()

        async def __aexit__(self, *args):
            self.closed = True
            return False

    opened: list[CmFactory] = []

    def make_cm(*args, **kwargs):
        opened.append(CmFactory())
        return opened[-1]

    connection = {"transport": "stdio", "command": "x", "args": []}
    pool = MCPSessionPool()

    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
        for scope in ("t1", "t2"):
            await pool.get_session("s", scope, connection)
        await pool.close_scope("t1")

    t1_cm, t2_cm = opened[0], opened[1]
    assert t1_cm.closed is True
    assert t2_cm.closed is False
    # t2 session still exists.
    assert ("s", "t2") in pool._entries
@pytest.mark.asyncio
async def test_close_all():
    """close_all exits every context manager and leaves the pool empty."""

    class CmFactory:
        """Context manager that records whether it has been exited."""

        def __init__(self):
            self.closed = False

        async def __aenter__(self):
            return AsyncMock()

        async def __aexit__(self, *args):
            self.closed = True
            return False

    opened: list[CmFactory] = []

    def make_cm(*args, **kwargs):
        opened.append(CmFactory())
        return opened[-1]

    connection = {"transport": "stdio", "command": "x", "args": []}
    pool = MCPSessionPool()

    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
        for server, scope in (("s1", "t1"), ("s2", "t2")):
            await pool.get_session(server, scope, connection)
        await pool.close_all()

    assert all(cm.closed for cm in opened)
    assert len(pool._entries) == 0
# ---------------------------------------------------------------------------
# Singleton helpers
# ---------------------------------------------------------------------------
def test_get_session_pool_singleton():
    """Repeated get_session_pool calls yield one and the same object."""
    first, second = get_session_pool(), get_session_pool()
    assert first is second
def test_reset_session_pool():
    """After a reset, get_session_pool builds a new instance."""
    before = get_session_pool()
    reset_session_pool()
    after = get_session_pool()
    assert before is not after
# ---------------------------------------------------------------------------
# Integration: _make_session_pool_tool uses the pool
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_session_pool_tool_wrapping():
    """The wrapper routes the call through a pooled session under the unprefixed tool name."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        url: str = Field(..., description="url")

    # Stand-in for a tool as returned by langchain-mcp-adapters.
    original_tool = StructuredTool(
        name="playwright_navigate",
        description="Navigate browser",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )

    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    fake_session = AsyncMock()
    fake_session.call_tool = AsyncMock(return_value=call_result)
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=fake_session)
    session_cm.__aexit__ = AsyncMock(return_value=False)

    connection = {"transport": "stdio", "command": "pw", "args": []}
    # Runtime whose context carries the thread_id.
    runtime = MagicMock(context={"thread_id": "thread-42"}, config={})

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=session_cm):
        wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
        await wrapped.coroutine(runtime=runtime, url="https://example.com")

    fake_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
@pytest.mark.asyncio
async def test_session_pool_tool_extracts_thread_id():
    """With an empty context, the thread id comes from runtime.config."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        x: int = Field(..., description="x")

    original_tool = StructuredTool(
        name="server_tool",
        description="test",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )

    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    fake_session = AsyncMock()
    fake_session.call_tool = AsyncMock(return_value=call_result)
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=fake_session)
    session_cm.__aexit__ = AsyncMock(return_value=False)

    connection = {"transport": "stdio", "command": "x", "args": []}
    runtime = MagicMock(context={}, config={"configurable": {"thread_id": "from-config"}})

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=session_cm):
        wrapped = _make_session_pool_tool(original_tool, "server", connection)
        await wrapped.coroutine(runtime=runtime, x=1)

    # Verify the session was created with the correct scope key.
    assert ("server", "from-config") in get_session_pool()._entries
@pytest.mark.asyncio
async def test_session_pool_tool_default_scope():
    """Without any runtime, the session is filed under the 'default' scope key."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        x: int = Field(..., description="x")

    original_tool = StructuredTool(
        name="server_tool",
        description="test",
        args_schema=Args,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )

    call_result = MagicMock(content=[], isError=False, structuredContent=None)
    fake_session = AsyncMock()
    fake_session.call_tool = AsyncMock(return_value=call_result)
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=fake_session)
    session_cm.__aexit__ = AsyncMock(return_value=False)

    connection = {"transport": "stdio", "command": "x", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=session_cm):
        wrapped = _make_session_pool_tool(original_tool, "server", connection)
        # No thread_id in runtime at all.
        await wrapped.coroutine(runtime=None, x=1)

    assert ("server", "default") in get_session_pool()._entries