fix(mcp): persist MCP sessions across tool calls for stateful servers (#3089)

* fix(mcp): persist MCP sessions across tool calls for stateful servers

  MCP tools loaded via langchain-mcp-adapters created a new session on
  every call, causing stateful servers like Playwright to lose browser
  state (pages, forms) between consecutive tool invocations within the
  same thread.

  Add MCPSessionPool that maintains persistent sessions scoped by
  (server_name, thread_id). Tool calls within the same thread now reuse
  the same MCP session, preserving server-side state. Sessions are evicted
  in LRU order (max 256) and cleaned up on cache invalidation.

  Fixes #3054

* fix(sandbox): add group/other read permissions to uploaded files for Docker sandbox (#3127)

  When using AIO sandbox with LocalContainerBackend, uploaded files are
  created with 0o600 (owner-only) permissions by the gateway process
  running as root. The sandbox process inside the Docker container runs
  as a non-root user and cannot read these bind-mounted files, causing
  a "Permission denied" error on read_file.

  Add `needs_upload_permission_adjustment` attribute to SandboxProvider
  (default True) to indicate that uploaded files need chmod adjustment.
  LocalSandboxProvider opts out (same user). A new `_make_file_sandbox_readable`
  function adds S_IRGRP | S_IROTH bits after files are written, changing
  permissions from 0o600 to 0o644 so the sandbox can read the uploads.

* fix(mcp): address review comments on session pool and tools

- _extract_thread_id: return "default" instead of stringifying None
  when get_config() returns no thread_id
- call_with_persistent_session: fix **arguments annotation from
  dict[str,Any] to Any
- Replace private _convert_call_tool_result import with a local
  implementation that handles all MCP content block types
- _make_session_pool_tool: accept tool_interceptors and apply the
  configured interceptor chain on every call (preserving OAuth and
  custom interceptors)
- MCPSessionPool: replace asyncio.Lock with threading.Lock; restructure
  get/close methods to never await while holding the lock; add
  close_all_sync() that closes sessions on their owning event loops
- reset_mcp_tools_cache: use pool.close_all_sync() instead of
  asyncio.run-in-thread to close sessions deterministically
- test: add test_session_pool_tool_sync_wrapper_path_is_safe covering
  tool invocation via the sync wrapper (tool.func) path

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/9e7f9e7f-1d2b-464a-b3b7-7f1649b74122

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* fix(mcp): extract SESSION_CLOSE_TIMEOUT to class constant

Agent-Logs-Url: https://github.com/bytedance/deer-flow/sessions/9e7f9e7f-1d2b-464a-b3b7-7f1649b74122

Co-authored-by: WillemJiang <219644+WillemJiang@users.noreply.github.com>

* Potential fix for pull request finding 'Empty except'

Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <223894421+github-code-quality[bot]@users.noreply.github.com>
This commit is contained in:
Willem Jiang
2026-05-21 23:22:20 +08:00
committed by GitHub
parent e93f658472
commit c881d95898
4 changed files with 813 additions and 8 deletions
@@ -134,9 +134,25 @@ def reset_mcp_tools_cache() -> None:
"""Reset the MCP tools cache.
This is useful for testing or when you want to reload MCP tools.
Also closes all persistent MCP sessions so they are recreated on
the next tool load.
"""
global _mcp_tools_cache, _cache_initialized, _config_mtime
_mcp_tools_cache = None
_cache_initialized = False
_config_mtime = None
# Close persistent sessions they will be recreated by the next
# get_mcp_tools() call with the (possibly updated) connection config.
try:
from deerflow.mcp.session_pool import get_session_pool
pool = get_session_pool()
pool.close_all_sync()
except Exception:
logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
from deerflow.mcp.session_pool import reset_session_pool
reset_session_pool()
logger.info("MCP tools cache reset")
@@ -0,0 +1,198 @@
"""Persistent MCP session pool for stateful tool calls.
When MCP tools are loaded via langchain-mcp-adapters with ``session=None``,
each tool call creates a new MCP session. For stateful servers like Playwright,
this means browser state (opened pages, filled forms) is lost between calls.
This module provides a session pool that maintains persistent MCP sessions,
scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id —
so that consecutive tool calls share the same session and server-side state.
Sessions are evicted in LRU order when the pool reaches capacity.
"""
from __future__ import annotations
import asyncio
import logging
import threading
from collections import OrderedDict
from typing import Any
from mcp import ClientSession
logger = logging.getLogger(__name__)
class MCPSessionPool:
"""Manages persistent MCP sessions scoped by ``(server_name, scope_key)``."""
MAX_SESSIONS = 256
SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe
def __init__(self) -> None:
self._entries: OrderedDict[
tuple[str, str],
tuple[ClientSession, asyncio.AbstractEventLoop],
] = OrderedDict()
self._context_managers: dict[tuple[str, str], Any] = {}
# threading.Lock is not bound to any event loop, so it is safe to
# acquire from both async paths and sync/worker-thread paths.
self._lock = threading.Lock()
async def get_session(
self,
server_name: str,
scope_key: str,
connection: dict[str, Any],
) -> ClientSession:
"""Get or create a persistent MCP session.
If an existing session was created in a different event loop (e.g.
the sync-wrapper path), it is closed and replaced with a fresh one
in the current loop.
Args:
server_name: MCP server name.
scope_key: Isolation key (typically thread_id).
connection: Connection configuration for ``create_session``.
Returns:
An initialized ``ClientSession``.
"""
key = (server_name, scope_key)
current_loop = asyncio.get_running_loop()
# Phase 1: inspect/mutate the registry under the thread lock (no awaits).
cms_to_close: list[tuple[tuple[str, str], Any]] = []
with self._lock:
if key in self._entries:
session, loop = self._entries[key]
if loop is current_loop:
self._entries.move_to_end(key)
return session
# Session belongs to a different event loop evict it.
cm = self._context_managers.pop(key, None)
self._entries.pop(key)
if cm is not None:
cms_to_close.append((key, cm))
# Evict LRU entries when at capacity.
while len(self._entries) >= self.MAX_SESSIONS:
oldest_key = next(iter(self._entries))
cm = self._context_managers.pop(oldest_key, None)
self._entries.pop(oldest_key)
if cm is not None:
cms_to_close.append((oldest_key, cm))
# Phase 2: async cleanup outside the lock so we never await while holding it.
for close_key, cm in cms_to_close:
try:
await cm.__aexit__(None, None, None)
except Exception:
logger.warning("Error closing MCP session %s", close_key, exc_info=True)
from langchain_mcp_adapters.sessions import create_session
cm = create_session(connection)
session = await cm.__aenter__()
await session.initialize()
# Phase 3: register the new session under the lock.
with self._lock:
self._entries[key] = (session, current_loop)
self._context_managers[key] = cm
logger.info("Created persistent MCP session for %s/%s", server_name, scope_key)
return session
# ------------------------------------------------------------------
# Cleanup helpers
# ------------------------------------------------------------------
async def _close_cm(self, key: tuple[str, str], cm: Any) -> None:
"""Close a single context manager (must be called WITHOUT the lock)."""
try:
await cm.__aexit__(None, None, None)
except Exception:
logger.warning("Error closing MCP session %s", key, exc_info=True)
async def close_scope(self, scope_key: str) -> None:
"""Close all sessions for a given scope (e.g. thread_id)."""
with self._lock:
keys = [k for k in self._entries if k[1] == scope_key]
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
for k in keys:
self._entries.pop(k, None)
for key, cm in cms:
if cm is not None:
await self._close_cm(key, cm)
async def close_server(self, server_name: str) -> None:
"""Close all sessions for a given server."""
with self._lock:
keys = [k for k in self._entries if k[0] == server_name]
cms = [(k, self._context_managers.pop(k, None)) for k in keys]
for k in keys:
self._entries.pop(k, None)
for key, cm in cms:
if cm is not None:
await self._close_cm(key, cm)
async def close_all(self) -> None:
"""Close every managed session."""
with self._lock:
cms = list(self._context_managers.items())
self._context_managers.clear()
self._entries.clear()
for key, cm in cms:
await self._close_cm(key, cm)
def close_all_sync(self) -> None:
"""Close all sessions using their owning event loops (synchronous).
Each session is closed on the loop it was created in, avoiding
cross-loop resource leaks. Safe to call from any thread without an
active event loop.
"""
with self._lock:
entries = list(self._entries.items())
cms = dict(self._context_managers)
self._entries.clear()
self._context_managers.clear()
for key, (_, loop) in entries:
cm = cms.get(key)
if cm is None or loop.is_closed():
continue
try:
if loop.is_running():
# Schedule on the owning loop from this (different) thread.
future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop)
future.result(timeout=self.SESSION_CLOSE_TIMEOUT)
else:
loop.run_until_complete(cm.__aexit__(None, None, None))
except Exception:
logger.debug("Error closing MCP session %s during sync close", key, exc_info=True)
# ------------------------------------------------------------------
# Module-level singleton
# ------------------------------------------------------------------
_pool: MCPSessionPool | None = None
_pool_lock = threading.Lock()
def get_session_pool() -> MCPSessionPool:
"""Return the global session-pool singleton."""
global _pool
if _pool is None:
with _pool_lock:
if _pool is None:
_pool = MCPSessionPool()
return _pool
def reset_session_pool() -> None:
"""Reset the singleton (for tests)."""
global _pool
_pool = None
+190 -8
View File
@@ -1,21 +1,181 @@
"""Load MCP tools using langchain-mcp-adapters."""
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
from __future__ import annotations
import logging
from typing import Any
from langchain_core.tools import BaseTool
from langchain_core.tools import BaseTool, StructuredTool
from langgraph.config import get_config
from deerflow.config.extensions_config import ExtensionsConfig
from deerflow.mcp.client import build_servers_config
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
from deerflow.mcp.session_pool import get_session_pool
from deerflow.reflection import resolve_variable
from deerflow.tools.sync import make_sync_tool_wrapper
from deerflow.tools.types import Runtime
logger = logging.getLogger(__name__)
def _extract_thread_id(runtime: Runtime | None) -> str:
"""Extract thread_id from the injected tool runtime or LangGraph config."""
if runtime is not None:
tid = runtime.context.get("thread_id") if runtime.context else None
if tid is not None:
return str(tid)
config = runtime.config or {}
tid = config.get("configurable", {}).get("thread_id")
if tid is not None:
return str(tid)
try:
tid = get_config().get("configurable", {}).get("thread_id")
return str(tid) if tid is not None else "default"
except RuntimeError:
return "default"
def _convert_call_tool_result(call_tool_result: Any) -> Any:
"""Convert an MCP CallToolResult to the LangChain ``content_and_artifact`` format.
Implements the same conversion logic as the adapter without relying on
the private ``langchain_mcp_adapters.tools._convert_call_tool_result`` symbol.
"""
from langchain_core.messages import ToolMessage
from langchain_core.messages.content import create_file_block, create_image_block, create_text_block
from langchain_core.tools import ToolException
from mcp.types import EmbeddedResource, ImageContent, ResourceLink, TextContent, TextResourceContents
# Pass ToolMessage through directly (interceptor short-circuit).
if isinstance(call_tool_result, ToolMessage):
return call_tool_result, None
# Pass LangGraph Command through directly when langgraph is installed.
try:
from langgraph.types import Command
if isinstance(call_tool_result, Command):
return call_tool_result, None
except ImportError:
# langgraph is optional; if unavailable, continue with standard MCP content conversion.
pass
# Convert MCP content blocks to LangChain content blocks.
lc_content = []
for item in call_tool_result.content:
if isinstance(item, TextContent):
lc_content.append(create_text_block(text=item.text))
elif isinstance(item, ImageContent):
lc_content.append(create_image_block(base64=item.data, mime_type=item.mimeType))
elif isinstance(item, ResourceLink):
mime = item.mimeType or None
if mime and mime.startswith("image/"):
lc_content.append(create_image_block(url=str(item.uri), mime_type=mime))
else:
lc_content.append(create_file_block(url=str(item.uri), mime_type=mime))
elif isinstance(item, EmbeddedResource):
from mcp.types import BlobResourceContents
res = item.resource
if isinstance(res, TextResourceContents):
lc_content.append(create_text_block(text=res.text))
elif isinstance(res, BlobResourceContents):
mime = res.mimeType or None
if mime and mime.startswith("image/"):
lc_content.append(create_image_block(base64=res.blob, mime_type=mime))
else:
lc_content.append(create_file_block(base64=res.blob, mime_type=mime))
else:
lc_content.append(create_text_block(text=str(res)))
else:
lc_content.append(create_text_block(text=str(item)))
if call_tool_result.isError:
error_parts = [item["text"] for item in lc_content if isinstance(item, dict) and item.get("type") == "text"]
raise ToolException("\n".join(error_parts) if error_parts else str(lc_content))
artifact = None
if call_tool_result.structuredContent is not None:
artifact = {"structured_content": call_tool_result.structuredContent}
return lc_content, artifact
def _make_session_pool_tool(
tool: BaseTool,
server_name: str,
connection: dict[str, Any],
tool_interceptors: list[Any] | None = None,
) -> BaseTool:
"""Wrap an MCP tool so it reuses a persistent session from the pool.
Replaces the per-call session creation with pool-managed sessions scoped
by ``(server_name, thread_id)``. This ensures stateful MCP servers (e.g.
Playwright) keep their state across tool calls within the same thread.
The configured ``tool_interceptors`` (OAuth, custom) are preserved and
applied on every call before invoking the pooled session.
"""
# Strip the server-name prefix to recover the original MCP tool name.
original_name = tool.name
prefix = f"{server_name}_"
if original_name.startswith(prefix):
original_name = original_name[len(prefix) :]
pool = get_session_pool()
async def call_with_persistent_session(
runtime: Runtime | None = None,
**arguments: Any,
) -> Any:
thread_id = _extract_thread_id(runtime)
session = await pool.get_session(server_name, thread_id, connection)
if tool_interceptors:
from langchain_mcp_adapters.interceptors import MCPToolCallRequest
async def base_handler(request: MCPToolCallRequest) -> Any:
return await session.call_tool(request.name, request.args)
handler = base_handler
for interceptor in reversed(tool_interceptors):
outer = handler
async def wrapped(req: Any, _i: Any = interceptor, _h: Any = outer) -> Any:
return await _i(req, _h)
handler = wrapped
request = MCPToolCallRequest(
name=original_name,
args=arguments,
server_name=server_name,
runtime=runtime,
)
call_tool_result = await handler(request)
else:
call_tool_result = await session.call_tool(original_name, arguments)
return _convert_call_tool_result(call_tool_result)
return StructuredTool(
name=tool.name,
description=tool.description,
args_schema=tool.args_schema,
coroutine=call_with_persistent_session,
response_format="content_and_artifact",
metadata=tool.metadata,
)
async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers.
Tools are wrapped with persistent-session logic so that consecutive
calls within the same thread reuse the same MCP session.
Returns:
List of LangChain tools from all enabled MCP servers.
"""
@@ -50,7 +210,7 @@ async def get_mcp_tools() -> list[BaseTool]:
existing_headers["Authorization"] = auth_header
servers_config[server_name]["headers"] = existing_headers
tool_interceptors = []
tool_interceptors: list[Any] = []
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
if oauth_interceptor is not None:
tool_interceptors.append(oauth_interceptor)
@@ -74,20 +234,42 @@ async def get_mcp_tools() -> list[BaseTool]:
elif interceptor is not None:
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
except Exception as e:
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
logger.warning(
f"Failed to load MCP interceptor {interceptor_path}: {e}",
exc_info=True,
)
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
client = MultiServerMCPClient(
servers_config,
tool_interceptors=tool_interceptors,
tool_name_prefix=True,
)
# Get all tools from all servers
# Get all tools from all servers (discovers tool definitions via
# temporary sessions the persistent-session wrapping is applied below).
tools = await client.get_tools()
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Patch tools to support sync invocation, as deerflow client streams synchronously
# Wrap each tool with persistent-session logic.
wrapped_tools: list[BaseTool] = []
for tool in tools:
tool_server: str | None = None
for name in servers_config:
if tool.name.startswith(f"{name}_"):
tool_server = name
break
if tool_server is not None:
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
else:
wrapped_tools.append(tool)
# Patch tools to support sync invocation, as deerflow client streams synchronously
for tool in wrapped_tools:
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
return tools
return wrapped_tools
except Exception as e:
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)