def _run_blocking(coro_factory) -> None:
    """Run ``coro_factory()`` to completion from synchronous code.

    Args:
        coro_factory: Zero-argument callable returning a fresh coroutine.
            A factory (rather than a coroutine object) is taken so the
            coroutine is created in the thread that actually runs it.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No event loop is running in this thread, so one can be started here.
        asyncio.run(coro_factory())
        return

    # A loop is already running in this thread and cannot be re-entered, so
    # the coroutine is driven by a private loop on a short-lived worker
    # thread. ``result()`` blocks the caller until it finishes and re-raises
    # anything the coroutine raised.
    import concurrent.futures

    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
        executor.submit(lambda: asyncio.run(coro_factory())).result()


def reset_mcp_tools_cache() -> None:
    """Reset the MCP tools cache and drop all persistent MCP sessions.

    Clears the cached tool list, the initialization flag and the recorded
    config mtime. Then, on a best-effort basis, closes every session held by
    the session pool and discards the pool singleton, so that the next tool
    load recreates sessions with the (possibly updated) connection config.

    Session cleanup never raises: failures are logged at debug level and the
    cache reset still completes.
    """
    global _mcp_tools_cache, _cache_initialized, _config_mtime
    _mcp_tools_cache = None
    _cache_initialized = False
    _config_mtime = None

    try:
        # Both names are imported inside the guarded block: session_pool
        # imports third-party ``mcp`` at module level, and a failure there
        # must not turn a cache reset into a crash.
        from deerflow.mcp.session_pool import get_session_pool, reset_session_pool
    except ImportError:
        logger.debug("MCP session pool unavailable; no sessions to close", exc_info=True)
    else:
        pool = get_session_pool()
        try:
            _run_blocking(pool.close_all)
        except Exception:
            logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
        finally:
            # Drop the singleton even if closing failed, so stale sessions
            # are never handed out after a reset.
            reset_session_pool()

    logger.info("MCP tools cache reset")
+ +This module provides a session pool that maintains persistent MCP sessions, +scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id — +so that consecutive tool calls share the same session and server-side state. +Sessions are evicted in LRU order when the pool reaches capacity. +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +from collections import OrderedDict +from typing import Any + +from mcp import ClientSession + +logger = logging.getLogger(__name__) + + +class MCPSessionPool: + """Manages persistent MCP sessions scoped by ``(server_name, scope_key)``.""" + + MAX_SESSIONS = 256 + + def __init__(self) -> None: + self._entries: OrderedDict[ + tuple[str, str], + tuple[ClientSession, asyncio.AbstractEventLoop], + ] = OrderedDict() + self._context_managers: dict[tuple[str, str], Any] = {} + self._lock = asyncio.Lock() + + async def get_session( + self, + server_name: str, + scope_key: str, + connection: dict[str, Any], + ) -> ClientSession: + """Get or create a persistent MCP session. + + If an existing session was created in a different event loop (e.g. + the sync-wrapper path), it is closed and replaced with a fresh one + in the current loop. + + Args: + server_name: MCP server name. + scope_key: Isolation key (typically thread_id). + connection: Connection configuration for ``create_session``. + + Returns: + An initialized ``ClientSession``. + """ + key = (server_name, scope_key) + current_loop = asyncio.get_running_loop() + + async with self._lock: + if key in self._entries: + session, loop = self._entries[key] + if loop is current_loop: + self._entries.move_to_end(key) + return session + # Session belongs to a different event loop – close it. + await self._close_session(key) + + # Evict oldest entries when at capacity. 
+ while len(self._entries) >= self.MAX_SESSIONS: + oldest_key = next(iter(self._entries)) + await self._close_session(oldest_key) + + from langchain_mcp_adapters.sessions import create_session + + cm = create_session(connection) + session = await cm.__aenter__() + await session.initialize() + self._entries[key] = (session, current_loop) + self._context_managers[key] = cm + logger.info("Created persistent MCP session for %s/%s", server_name, scope_key) + return session + + # ------------------------------------------------------------------ + # Cleanup helpers + # ------------------------------------------------------------------ + + async def _close_session(self, key: tuple[str, str]) -> None: + cm = self._context_managers.pop(key, None) + self._entries.pop(key, None) + if cm is not None: + try: + await cm.__aexit__(None, None, None) + except Exception: + logger.warning("Error closing MCP session %s", key, exc_info=True) + + async def close_scope(self, scope_key: str) -> None: + """Close all sessions for a given scope (e.g. 
thread_id).""" + async with self._lock: + keys_to_close = [k for k in self._entries if k[1] == scope_key] + for key in keys_to_close: + await self._close_session(key) + + async def close_server(self, server_name: str) -> None: + """Close all sessions for a given server.""" + async with self._lock: + keys_to_close = [k for k in self._entries if k[0] == server_name] + for key in keys_to_close: + await self._close_session(key) + + async def close_all(self) -> None: + """Close every managed session.""" + async with self._lock: + for key in list(self._context_managers.keys()): + await self._close_session(key) + + +# ------------------------------------------------------------------ +# Module-level singleton +# ------------------------------------------------------------------ + +_pool: MCPSessionPool | None = None +_pool_lock = threading.Lock() + + +def get_session_pool() -> MCPSessionPool: + """Return the global session-pool singleton.""" + global _pool + if _pool is None: + with _pool_lock: + if _pool is None: + _pool = MCPSessionPool() + return _pool + + +def reset_session_pool() -> None: + """Reset the singleton (for tests).""" + global _pool + _pool = None diff --git a/backend/packages/harness/deerflow/mcp/tools.py b/backend/packages/harness/deerflow/mcp/tools.py index d27641692..47d6ed99b 100644 --- a/backend/packages/harness/deerflow/mcp/tools.py +++ b/backend/packages/harness/deerflow/mcp/tools.py @@ -1,21 +1,83 @@ -"""Load MCP tools using langchain-mcp-adapters.""" +"""Load MCP tools using langchain-mcp-adapters with persistent sessions.""" + +from __future__ import annotations import logging +from typing import Annotated, Any -from langchain_core.tools import BaseTool +from langchain_core.tools import BaseTool, InjectedToolArg, StructuredTool from deerflow.config.extensions_config import ExtensionsConfig from deerflow.mcp.client import build_servers_config from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers +from 
def _lookup(container: Any, key: str) -> Any:
    """Read *key* from a mapping-like object or, failing that, an attribute.

    A ``None`` container and a missing key/attribute both yield ``None``.
    """
    if container is None:
        return None
    getter = getattr(container, "get", None)
    if callable(getter):
        return getter(key)
    return getattr(container, key, None)


def _extract_thread_id(runtime: Any) -> str:
    """Extract thread_id from the injected tool runtime.

    ``runtime.context`` is consulted first, then
    ``runtime.config["configurable"]``. Either may be a mapping or a plain
    object exposing the value as an attribute.

    Args:
        runtime: The injected runtime object, or ``None``.

    Returns:
        The thread id as a string, or ``"default"`` when none is available.
    """
    if runtime is None:
        return "default"

    thread_id = _lookup(getattr(runtime, "context", None), "thread_id")
    if thread_id is None:
        configurable = _lookup(getattr(runtime, "config", None), "configurable")
        thread_id = _lookup(configurable, "thread_id")

    return "default" if thread_id is None else str(thread_id)


def _make_session_pool_tool(
    tool: BaseTool,
    server_name: str,
    connection: dict[str, Any],
) -> BaseTool:
    """Wrap an MCP tool so it reuses a persistent session from the pool.

    The returned tool keeps the name, description, argument schema and
    metadata of *tool*, but executes through a pool-managed session scoped by
    ``(server_name, thread_id)`` so that stateful MCP servers (e.g.
    Playwright) keep their state across calls within one thread.

    NOTE(review): the wrapper calls ``session.call_tool`` directly and never
    invokes the original tool's coroutine, so anything the adapter layered
    onto that coroutine (e.g. client-level tool interceptors) looks to be
    bypassed on this path — confirm this is intended.

    Args:
        tool: Tool produced by langchain-mcp-adapters, named
            ``"<server_name>_<mcp tool name>"``.
        server_name: Name of the MCP server that provides the tool.
        connection: Connection configuration used to open sessions.

    Returns:
        A ``StructuredTool`` backed by persistent sessions.
    """
    # Strip the server-name prefix to recover the original MCP tool name.
    original_name = tool.name.removeprefix(f"{server_name}_")

    async def call_with_persistent_session(
        runtime: Annotated[object | None, InjectedToolArg()] = None,
        **arguments: Any,
    ) -> Any:
        thread_id = _extract_thread_id(runtime)
        # Resolve the pool on every call instead of capturing it at wrap time:
        # after reset_session_pool() a captured reference would keep opening
        # sessions in an orphaned pool that nothing closes any more.
        session = await get_session_pool().get_session(server_name, thread_id, connection)
        call_tool_result = await session.call_tool(original_name, arguments)

        # Private adapter helper; imported at call time.
        from langchain_mcp_adapters.tools import _convert_call_tool_result

        return _convert_call_tool_result(call_tool_result)

    return StructuredTool(
        name=tool.name,
        description=tool.description,
        args_schema=tool.args_schema,
        coroutine=call_with_persistent_session,
        response_format="content_and_artifact",
        metadata=tool.metadata,
    )
""" @@ -50,7 +112,7 @@ async def get_mcp_tools() -> list[BaseTool]: existing_headers["Authorization"] = auth_header servers_config[server_name]["headers"] = existing_headers - tool_interceptors = [] + tool_interceptors: list[Any] = [] oauth_interceptor = build_oauth_tool_interceptor(extensions_config) if oauth_interceptor is not None: tool_interceptors.append(oauth_interceptor) @@ -74,20 +136,42 @@ async def get_mcp_tools() -> list[BaseTool]: elif interceptor is not None: logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping") except Exception as e: - logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True) + logger.warning( + f"Failed to load MCP interceptor {interceptor_path}: {e}", + exc_info=True, + ) - client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True) + client = MultiServerMCPClient( + servers_config, + tool_interceptors=tool_interceptors, + tool_name_prefix=True, + ) - # Get all tools from all servers + # Get all tools from all servers (discovers tool definitions via + # temporary sessions – the persistent-session wrapping is applied below). tools = await client.get_tools() logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers") - # Patch tools to support sync invocation, as deerflow client streams synchronously + # Wrap each tool with persistent-session logic. 
# Dotted path that the pool resolves lazily; every test patches it.
_CREATE_SESSION = "langchain_mcp_adapters.sessions.create_session"


@pytest.fixture(autouse=True)
def _reset_pool():
    """Give every test a fresh pool singleton and discard it afterwards."""
    reset_session_pool()
    yield
    reset_session_pool()


# ---------------------------------------------------------------------------
# Shared test doubles
# ---------------------------------------------------------------------------


def _stdio(command="x"):
    """Build a minimal stdio connection config for *command*."""
    return {"transport": "stdio", "command": command, "args": []}


def _scripted_cm(session):
    """Return a mock async context manager that yields *session*."""
    cm = MagicMock()
    cm.__aenter__ = AsyncMock(return_value=session)
    cm.__aexit__ = AsyncMock(return_value=False)
    return cm


class _TrackingCM:
    """Async context manager that records whether it has been exited."""

    def __init__(self):
        self.closed = False

    async def __aenter__(self):
        return AsyncMock()

    async def __aexit__(self, *args):
        self.closed = True
        return False


def _tracking_factory():
    """Return ``(created, factory)``; each factory call appends a new _TrackingCM."""
    created = []

    def factory(*args, **kwargs):
        cm = _TrackingCM()
        created.append(cm)
        return cm

    return created, factory


def _empty_result_session():
    """Return a mock session whose ``call_tool`` yields an empty, non-error result."""
    session = AsyncMock()
    session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
    return session


def _dummy_tool(name, description, args_schema):
    """Build a StructuredTool shaped like the ones langchain-mcp-adapters returns."""
    from langchain_core.tools import StructuredTool

    return StructuredTool(
        name=name,
        description=description,
        args_schema=args_schema,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )


# ---------------------------------------------------------------------------
# MCPSessionPool unit tests
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_get_session_creates_new():
    """First call for a key creates a new session."""
    pool = MCPSessionPool()
    mock_session = AsyncMock()

    with patch(_CREATE_SESSION, return_value=_scripted_cm(mock_session)):
        session = await pool.get_session("server", "thread-1", _stdio())

    assert session is mock_session
    mock_session.initialize.assert_awaited_once()


@pytest.mark.asyncio
async def test_get_session_reuses_existing():
    """Second call for the same key returns the cached session."""
    pool = MCPSessionPool()
    mock_cm = _scripted_cm(AsyncMock())

    with patch(_CREATE_SESSION, return_value=mock_cm):
        first = await pool.get_session("server", "thread-1", _stdio())
        second = await pool.get_session("server", "thread-1", _stdio())

    assert first is second
    # Only one session should have been created.
    assert mock_cm.__aenter__.await_count == 1


@pytest.mark.asyncio
async def test_different_scope_creates_different_session():
    """Different scope keys get different sessions."""
    pool = MCPSessionPool()

    sessions = [AsyncMock(), AsyncMock()]
    pending = iter(sessions)

    class SequentialCM:
        # Hands out the prepared sessions in order, one per __aenter__.
        async def __aenter__(self):
            return next(pending)

        async def __aexit__(self, *args):
            return False

    with patch(_CREATE_SESSION, side_effect=lambda *a, **kw: SequentialCM()):
        s1 = await pool.get_session("server", "thread-1", _stdio())
        s2 = await pool.get_session("server", "thread-2", _stdio())

    assert s1 is not s2
    assert s1 is sessions[0]
    assert s2 is sessions[1]


@pytest.mark.asyncio
async def test_lru_eviction():
    """Oldest entries are evicted when the pool is full."""
    pool = MCPSessionPool()
    pool.MAX_SESSIONS = 2
    cms, factory = _tracking_factory()

    with patch(_CREATE_SESSION, side_effect=factory):
        await pool.get_session("s", "t1", _stdio())
        await pool.get_session("s", "t2", _stdio())
        # Pool is full (2). Adding t3 should evict t1.
        await pool.get_session("s", "t3", _stdio())

    assert [cm.closed for cm in cms] == [True, False, False]


@pytest.mark.asyncio
async def test_close_scope():
    """close_scope shuts down sessions for a specific scope key."""
    pool = MCPSessionPool()
    cms, factory = _tracking_factory()

    with patch(_CREATE_SESSION, side_effect=factory):
        await pool.get_session("s", "t1", _stdio())
        await pool.get_session("s", "t2", _stdio())

    await pool.close_scope("t1")

    assert cms[0].closed is True
    assert cms[1].closed is False

    # t2 session still exists.
    assert ("s", "t2") in pool._entries


@pytest.mark.asyncio
async def test_close_all():
    """close_all shuts down every session."""
    pool = MCPSessionPool()
    cms, factory = _tracking_factory()

    with patch(_CREATE_SESSION, side_effect=factory):
        await pool.get_session("s1", "t1", _stdio())
        await pool.get_session("s2", "t2", _stdio())

    await pool.close_all()

    assert all(cm.closed for cm in cms)
    assert len(pool._entries) == 0


# ---------------------------------------------------------------------------
# Singleton helpers
# ---------------------------------------------------------------------------


def test_get_session_pool_singleton():
    """get_session_pool returns the same instance."""
    assert get_session_pool() is get_session_pool()


def test_reset_session_pool():
    """reset_session_pool clears the singleton."""
    before = get_session_pool()
    reset_session_pool()
    after = get_session_pool()
    assert before is not after


# ---------------------------------------------------------------------------
# Integration: _make_session_pool_tool uses the pool
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_session_pool_tool_wrapping():
    """The wrapper tool delegates to a pool-managed session."""
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        url: str = Field(..., description="url")

    # Build a dummy StructuredTool (as returned by langchain-mcp-adapters).
    original_tool = _dummy_tool("playwright_navigate", "Navigate browser", Args)
    mock_session = _empty_result_session()

    with patch(_CREATE_SESSION, return_value=_scripted_cm(mock_session)):
        wrapped = _make_session_pool_tool(original_tool, "playwright", _stdio("pw"))

        # Simulate a tool call with a runtime context containing thread_id.
        mock_runtime = MagicMock()
        mock_runtime.context = {"thread_id": "thread-42"}
        mock_runtime.config = {}

        await wrapped.coroutine(runtime=mock_runtime, url="https://example.com")

    mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})


@pytest.mark.asyncio
async def test_session_pool_tool_extracts_thread_id():
    """Thread ID is extracted from runtime.config when not in context."""
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        x: int = Field(..., description="x")

    original_tool = _dummy_tool("server_tool", "test", Args)

    with patch(_CREATE_SESSION, return_value=_scripted_cm(_empty_result_session())):
        wrapped = _make_session_pool_tool(original_tool, "server", _stdio())

        mock_runtime = MagicMock()
        mock_runtime.context = {}
        mock_runtime.config = {"configurable": {"thread_id": "from-config"}}

        await wrapped.coroutine(runtime=mock_runtime, x=1)

    # Verify the session was created with the correct scope key.
    assert ("server", "from-config") in get_session_pool()._entries


@pytest.mark.asyncio
async def test_session_pool_tool_default_scope():
    """When no thread_id is available, 'default' is used as scope key."""
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        x: int = Field(..., description="x")

    original_tool = _dummy_tool("server_tool", "test", Args)

    with patch(_CREATE_SESSION, return_value=_scripted_cm(_empty_result_session())):
        wrapped = _make_session_pool_tool(original_tool, "server", _stdio())

        # No thread_id in runtime at all.
        await wrapped.coroutine(runtime=None, x=1)

    assert ("server", "default") in get_session_pool()._entries