diff --git a/backend/packages/harness/deerflow/mcp/cache.py b/backend/packages/harness/deerflow/mcp/cache.py index c1121f59d..f04fe0054 100644 --- a/backend/packages/harness/deerflow/mcp/cache.py +++ b/backend/packages/harness/deerflow/mcp/cache.py @@ -134,9 +134,25 @@ def reset_mcp_tools_cache() -> None: """Reset the MCP tools cache. This is useful for testing or when you want to reload MCP tools. + Also closes all persistent MCP sessions so they are recreated on + the next tool load. """ global _mcp_tools_cache, _cache_initialized, _config_mtime _mcp_tools_cache = None _cache_initialized = False _config_mtime = None + + # Close persistent sessions – they will be recreated by the next + # get_mcp_tools() call with the (possibly updated) connection config. + try: + from deerflow.mcp.session_pool import get_session_pool + + pool = get_session_pool() + pool.close_all_sync() + except Exception: + logger.debug("Could not close MCP session pool on cache reset", exc_info=True) + + from deerflow.mcp.session_pool import reset_session_pool + + reset_session_pool() logger.info("MCP tools cache reset") diff --git a/backend/packages/harness/deerflow/mcp/session_pool.py b/backend/packages/harness/deerflow/mcp/session_pool.py new file mode 100644 index 000000000..8450cac8e --- /dev/null +++ b/backend/packages/harness/deerflow/mcp/session_pool.py @@ -0,0 +1,198 @@ +"""Persistent MCP session pool for stateful tool calls. + +When MCP tools are loaded via langchain-mcp-adapters with ``session=None``, +each tool call creates a new MCP session. For stateful servers like Playwright, +this means browser state (opened pages, filled forms) is lost between calls. + +This module provides a session pool that maintains persistent MCP sessions, +scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id — +so that consecutive tool calls share the same session and server-side state. +Sessions are evicted in LRU order when the pool reaches capacity. +""" + +from __future__ import annotations + +import asyncio +import logging +import threading +from collections import OrderedDict +from typing import Any + +from mcp import ClientSession + +logger = logging.getLogger(__name__) + + +class MCPSessionPool: + """Manages persistent MCP sessions scoped by ``(server_name, scope_key)``.""" + + MAX_SESSIONS = 256 + SESSION_CLOSE_TIMEOUT = 5.0 # seconds to wait when closing a session via run_coroutine_threadsafe + + def __init__(self) -> None: + self._entries: OrderedDict[ + tuple[str, str], + tuple[ClientSession, asyncio.AbstractEventLoop], + ] = OrderedDict() + self._context_managers: dict[tuple[str, str], Any] = {} + # threading.Lock is not bound to any event loop, so it is safe to + # acquire from both async paths and sync/worker-thread paths. + self._lock = threading.Lock() + + async def get_session( + self, + server_name: str, + scope_key: str, + connection: dict[str, Any], + ) -> ClientSession: + """Get or create a persistent MCP session. + + If an existing session was created in a different event loop (e.g. + the sync-wrapper path), it is closed and replaced with a fresh one + in the current loop. + + Args: + server_name: MCP server name. + scope_key: Isolation key (typically thread_id). + connection: Connection configuration for ``create_session``. + + Returns: + An initialized ``ClientSession``. + """ + key = (server_name, scope_key) + current_loop = asyncio.get_running_loop() + + # Phase 1: inspect/mutate the registry under the thread lock (no awaits). + cms_to_close: list[tuple[tuple[str, str], Any]] = [] + with self._lock: + if key in self._entries: + session, loop = self._entries[key] + if loop is current_loop: + self._entries.move_to_end(key) + return session + # Session belongs to a different event loop – evict it. + cm = self._context_managers.pop(key, None) + self._entries.pop(key) + if cm is not None: + cms_to_close.append((key, cm)) + + # Evict LRU entries when at capacity. + while len(self._entries) >= self.MAX_SESSIONS: + oldest_key = next(iter(self._entries)) + cm = self._context_managers.pop(oldest_key, None) + self._entries.pop(oldest_key) + if cm is not None: + cms_to_close.append((oldest_key, cm)) + + # Phase 2: async cleanup outside the lock so we never await while holding it. + for close_key, cm in cms_to_close: + try: + await cm.__aexit__(None, None, None) + except Exception: + logger.warning("Error closing MCP session %s", close_key, exc_info=True) + + from langchain_mcp_adapters.sessions import create_session + + cm = create_session(connection) + session = await cm.__aenter__() + await session.initialize() + + # Phase 3: register the new session under the lock. + with self._lock: + self._entries[key] = (session, current_loop) + self._context_managers[key] = cm + logger.info("Created persistent MCP session for %s/%s", server_name, scope_key) + return session + + # ------------------------------------------------------------------ + # Cleanup helpers + # ------------------------------------------------------------------ + + async def _close_cm(self, key: tuple[str, str], cm: Any) -> None: + """Close a single context manager (must be called WITHOUT the lock).""" + try: + await cm.__aexit__(None, None, None) + except Exception: + logger.warning("Error closing MCP session %s", key, exc_info=True) + + async def close_scope(self, scope_key: str) -> None: + """Close all sessions for a given scope (e.g. thread_id).""" + with self._lock: + keys = [k for k in self._entries if k[1] == scope_key] + cms = [(k, self._context_managers.pop(k, None)) for k in keys] + for k in keys: + self._entries.pop(k, None) + for key, cm in cms: + if cm is not None: + await self._close_cm(key, cm) + + async def close_server(self, server_name: str) -> None: + """Close all sessions for a given server.""" + with self._lock: + keys = [k for k in self._entries if k[0] == server_name] + cms = [(k, self._context_managers.pop(k, None)) for k in keys] + for k in keys: + self._entries.pop(k, None) + for key, cm in cms: + if cm is not None: + await self._close_cm(key, cm) + + async def close_all(self) -> None: + """Close every managed session.""" + with self._lock: + cms = list(self._context_managers.items()) + self._context_managers.clear() + self._entries.clear() + for key, cm in cms: + await self._close_cm(key, cm) + + def close_all_sync(self) -> None: + """Close all sessions using their owning event loops (synchronous). + + Each session is closed on the loop it was created in, avoiding + cross-loop resource leaks. Safe to call from any thread without an + active event loop. + """ + with self._lock: + entries = list(self._entries.items()) + cms = dict(self._context_managers) + self._entries.clear() + self._context_managers.clear() + + for key, (_, loop) in entries: + cm = cms.get(key) + if cm is None or loop.is_closed(): + continue + try: + if loop.is_running(): + # Schedule on the owning loop from this (different) thread. + future = asyncio.run_coroutine_threadsafe(cm.__aexit__(None, None, None), loop) + future.result(timeout=self.SESSION_CLOSE_TIMEOUT) + else: + loop.run_until_complete(cm.__aexit__(None, None, None)) + except Exception: + logger.debug("Error closing MCP session %s during sync close", key, exc_info=True) + + +# ------------------------------------------------------------------ +# Module-level singleton +# ------------------------------------------------------------------ + +_pool: MCPSessionPool | None = None +_pool_lock = threading.Lock() + + +def get_session_pool() -> MCPSessionPool: + """Return the global session-pool singleton.""" + global _pool + if _pool is None: + with _pool_lock: + if _pool is None: + _pool = MCPSessionPool() + return _pool + + +def reset_session_pool() -> None: + """Reset the singleton (for tests).""" + global _pool + _pool = None diff --git a/backend/packages/harness/deerflow/mcp/tools.py b/backend/packages/harness/deerflow/mcp/tools.py index d27641692..d08e7efd6 100644 --- a/backend/packages/harness/deerflow/mcp/tools.py +++ b/backend/packages/harness/deerflow/mcp/tools.py @@ -1,21 +1,181 @@ -"""Load MCP tools using langchain-mcp-adapters.""" +"""Load MCP tools using langchain-mcp-adapters with persistent sessions.""" + +from __future__ import annotations import logging +from typing import Any -from langchain_core.tools import BaseTool +from langchain_core.tools import BaseTool, StructuredTool +from langgraph.config import get_config from deerflow.config.extensions_config import ExtensionsConfig from deerflow.mcp.client import build_servers_config from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers +from deerflow.mcp.session_pool import get_session_pool from deerflow.reflection import resolve_variable from deerflow.tools.sync import make_sync_tool_wrapper +from deerflow.tools.types import Runtime logger = logging.getLogger(__name__) +def _extract_thread_id(runtime: Runtime | None) -> str: + """Extract thread_id from the injected tool runtime or LangGraph config.""" + if runtime is not None: + tid = runtime.context.get("thread_id") if runtime.context else None + if tid is not None: + return str(tid) + config = runtime.config or {} + tid = config.get("configurable", {}).get("thread_id") + if tid is not None: + return str(tid) + + try: + tid = get_config().get("configurable", {}).get("thread_id") + return str(tid) if tid is not None else "default" + except RuntimeError: + return "default" + + +def _convert_call_tool_result(call_tool_result: Any) -> Any: + """Convert an MCP CallToolResult to the LangChain ``content_and_artifact`` format. + + Implements the same conversion logic as the adapter without relying on + the private ``langchain_mcp_adapters.tools._convert_call_tool_result`` symbol. + """ + from langchain_core.messages import ToolMessage + from langchain_core.messages.content import create_file_block, create_image_block, create_text_block + from langchain_core.tools import ToolException + from mcp.types import EmbeddedResource, ImageContent, ResourceLink, TextContent, TextResourceContents + + # Pass ToolMessage through directly (interceptor short-circuit). + if isinstance(call_tool_result, ToolMessage): + return call_tool_result, None + + # Pass LangGraph Command through directly when langgraph is installed. + try: + from langgraph.types import Command + + if isinstance(call_tool_result, Command): + return call_tool_result, None + except ImportError: + # langgraph is optional; if unavailable, continue with standard MCP content conversion. + pass + + # Convert MCP content blocks to LangChain content blocks. + lc_content = [] + for item in call_tool_result.content: + if isinstance(item, TextContent): + lc_content.append(create_text_block(text=item.text)) + elif isinstance(item, ImageContent): + lc_content.append(create_image_block(base64=item.data, mime_type=item.mimeType)) + elif isinstance(item, ResourceLink): + mime = item.mimeType or None + if mime and mime.startswith("image/"): + lc_content.append(create_image_block(url=str(item.uri), mime_type=mime)) + else: + lc_content.append(create_file_block(url=str(item.uri), mime_type=mime)) + elif isinstance(item, EmbeddedResource): + from mcp.types import BlobResourceContents + + res = item.resource + if isinstance(res, TextResourceContents): + lc_content.append(create_text_block(text=res.text)) + elif isinstance(res, BlobResourceContents): + mime = res.mimeType or None + if mime and mime.startswith("image/"): + lc_content.append(create_image_block(base64=res.blob, mime_type=mime)) + else: + lc_content.append(create_file_block(base64=res.blob, mime_type=mime)) + else: + lc_content.append(create_text_block(text=str(res))) + else: + lc_content.append(create_text_block(text=str(item))) + + if call_tool_result.isError: + error_parts = [item["text"] for item in lc_content if isinstance(item, dict) and item.get("type") == "text"] + raise ToolException("\n".join(error_parts) if error_parts else str(lc_content)) + + artifact = None + if call_tool_result.structuredContent is not None: + artifact = {"structured_content": call_tool_result.structuredContent} + + return lc_content, artifact + + +def _make_session_pool_tool( + tool: BaseTool, + server_name: str, + connection: dict[str, Any], + tool_interceptors: list[Any] | None = None, +) -> BaseTool: + """Wrap an MCP tool so it reuses a persistent session from the pool. + + Replaces the per-call session creation with pool-managed sessions scoped + by ``(server_name, thread_id)``. This ensures stateful MCP servers (e.g. + Playwright) keep their state across tool calls within the same thread. + + The configured ``tool_interceptors`` (OAuth, custom) are preserved and + applied on every call before invoking the pooled session. + """ + # Strip the server-name prefix to recover the original MCP tool name. + original_name = tool.name + prefix = f"{server_name}_" + if original_name.startswith(prefix): + original_name = original_name[len(prefix) :] + + pool = get_session_pool() + + async def call_with_persistent_session( + runtime: Runtime | None = None, + **arguments: Any, + ) -> Any: + thread_id = _extract_thread_id(runtime) + session = await pool.get_session(server_name, thread_id, connection) + + if tool_interceptors: + from langchain_mcp_adapters.interceptors import MCPToolCallRequest + + async def base_handler(request: MCPToolCallRequest) -> Any: + return await session.call_tool(request.name, request.args) + + handler = base_handler + for interceptor in reversed(tool_interceptors): + outer = handler + + async def wrapped(req: Any, _i: Any = interceptor, _h: Any = outer) -> Any: + return await _i(req, _h) + + handler = wrapped + + request = MCPToolCallRequest( + name=original_name, + args=arguments, + server_name=server_name, + runtime=runtime, + ) + call_tool_result = await handler(request) + else: + call_tool_result = await session.call_tool(original_name, arguments) + + return _convert_call_tool_result(call_tool_result) + + return StructuredTool( + name=tool.name, + description=tool.description, + args_schema=tool.args_schema, + coroutine=call_with_persistent_session, + response_format="content_and_artifact", + metadata=tool.metadata, + ) + + async def get_mcp_tools() -> list[BaseTool]: """Get all tools from enabled MCP servers. + Tools are wrapped with persistent-session logic so that consecutive + calls within the same thread reuse the same MCP session. + Returns: List of LangChain tools from all enabled MCP servers. """ @@ -50,7 +210,7 @@ async def get_mcp_tools() -> list[BaseTool]: existing_headers["Authorization"] = auth_header servers_config[server_name]["headers"] = existing_headers - tool_interceptors = [] + tool_interceptors: list[Any] = [] oauth_interceptor = build_oauth_tool_interceptor(extensions_config) if oauth_interceptor is not None: tool_interceptors.append(oauth_interceptor) @@ -74,20 +234,42 @@ async def get_mcp_tools() -> list[BaseTool]: elif interceptor is not None: logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping") except Exception as e: - logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True) + logger.warning( + f"Failed to load MCP interceptor {interceptor_path}: {e}", + exc_info=True, + ) - client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True) + client = MultiServerMCPClient( + servers_config, + tool_interceptors=tool_interceptors, + tool_name_prefix=True, + ) - # Get all tools from all servers + # Get all tools from all servers (discovers tool definitions via + # temporary sessions – the persistent-session wrapping is applied below). tools = await client.get_tools() logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers") - # Patch tools to support sync invocation, as deerflow client streams synchronously + # Wrap each tool with persistent-session logic. + wrapped_tools: list[BaseTool] = [] for tool in tools: + tool_server: str | None = None + for name in servers_config: + if tool.name.startswith(f"{name}_"): + tool_server = name + break + + if tool_server is not None: + wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors)) + else: + wrapped_tools.append(tool) + + # Patch tools to support sync invocation, as deerflow client streams synchronously + for tool in wrapped_tools: if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None: tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name) - return tools + return wrapped_tools except Exception as e: logger.error(f"Failed to load MCP tools: {e}", exc_info=True) diff --git a/backend/tests/test_mcp_session_pool.py b/backend/tests/test_mcp_session_pool.py new file mode 100644 index 000000000..822ad2e81 --- /dev/null +++ b/backend/tests/test_mcp_session_pool.py @@ -0,0 +1,409 @@ +"""Tests for the MCP persistent-session pool.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool + + +@pytest.fixture(autouse=True) +def _reset_pool(): + reset_session_pool() + yield + reset_session_pool() + + +# --------------------------------------------------------------------------- +# MCPSessionPool unit tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_session_creates_new(): + """First call for a key creates a new session.""" + pool = MCPSessionPool() + + mock_session = AsyncMock() + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + session = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []}) + + assert session is mock_session + mock_session.initialize.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_session_reuses_existing(): + """Second call for the same key returns the cached session.""" + pool = MCPSessionPool() + + mock_session = AsyncMock() + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []}) + s2 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []}) + + assert s1 is s2 + # Only one session should have been created. + assert mock_cm.__aenter__.await_count == 1 + + +@pytest.mark.asyncio +async def test_different_scope_creates_different_session(): + """Different scope keys get different sessions.""" + pool = MCPSessionPool() + + sessions = [AsyncMock(), AsyncMock()] + idx = 0 + + class CmFactory: + def __init__(self): + self.enter_count = 0 + + async def __aenter__(self): + nonlocal idx + s = sessions[idx] + idx += 1 + self.enter_count += 1 + return s + + async def __aexit__(self, *args): + return False + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=lambda *a, **kw: CmFactory()): + s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []}) + s2 = await pool.get_session("server", "thread-2", {"transport": "stdio", "command": "x", "args": []}) + + assert s1 is not s2 + assert s1 is sessions[0] + assert s2 is sessions[1] + + +@pytest.mark.asyncio +async def test_lru_eviction(): + """Oldest entries are evicted when the pool is full.""" + pool = MCPSessionPool() + pool.MAX_SESSIONS = 2 + + class CmFactory: + def __init__(self): + self.closed = False + + async def __aenter__(self): + return AsyncMock() + + async def __aexit__(self, *args): + self.closed = True + return False + + cms: list[CmFactory] = [] + + def make_cm(*a, **kw): + cm = CmFactory() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []}) + await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []}) + # Pool is full (2). Adding t3 should evict t1. + await pool.get_session("s", "t3", {"transport": "stdio", "command": "x", "args": []}) + + assert cms[0].closed is True + assert cms[1].closed is False + assert cms[2].closed is False + + +@pytest.mark.asyncio +async def test_close_scope(): + """close_scope shuts down sessions for a specific scope key.""" + pool = MCPSessionPool() + + class CmFactory: + def __init__(self): + self.closed = False + + async def __aenter__(self): + return AsyncMock() + + async def __aexit__(self, *args): + self.closed = True + return False + + cms: list[CmFactory] = [] + + def make_cm(*a, **kw): + cm = CmFactory() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []}) + await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []}) + + await pool.close_scope("t1") + + assert cms[0].closed is True + assert cms[1].closed is False + + # t2 session still exists. + assert ("s", "t2") in pool._entries + + +@pytest.mark.asyncio +async def test_close_all(): + """close_all shuts down every session.""" + pool = MCPSessionPool() + + class CmFactory: + def __init__(self): + self.closed = False + + async def __aenter__(self): + return AsyncMock() + + async def __aexit__(self, *args): + self.closed = True + return False + + cms: list[CmFactory] = [] + + def make_cm(*a, **kw): + cm = CmFactory() + cms.append(cm) + return cm + + with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm): + await pool.get_session("s1", "t1", {"transport": "stdio", "command": "x", "args": []}) + await pool.get_session("s2", "t2", {"transport": "stdio", "command": "x", "args": []}) + + await pool.close_all() + + assert all(cm.closed for cm in cms) + assert len(pool._entries) == 0 + + +# --------------------------------------------------------------------------- +# Singleton helpers +# --------------------------------------------------------------------------- + + +def test_get_session_pool_singleton(): + """get_session_pool returns the same instance.""" + p1 = get_session_pool() + p2 = get_session_pool() + assert p1 is p2 + + +def test_reset_session_pool(): + """reset_session_pool clears the singleton.""" + p1 = get_session_pool() + reset_session_pool() + p2 = get_session_pool() + assert p1 is not p2 + + +# --------------------------------------------------------------------------- +# Integration: _make_session_pool_tool uses the pool +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_session_pool_tool_wrapping(): + """The wrapper tool delegates to a pool-managed session.""" + # Build a dummy StructuredTool (as returned by langchain-mcp-adapters). + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + url: str = Field(..., description="url") + + original_tool = StructuredTool( + name="playwright_navigate", + description="Navigate browser", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + connection = {"transport": "stdio", "command": "pw", "args": []} + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool(original_tool, "playwright", connection) + + # Simulate a tool call with a runtime context containing thread_id. + mock_runtime = MagicMock() + mock_runtime.context = {"thread_id": "thread-42"} + mock_runtime.config = {} + + await wrapped.coroutine(runtime=mock_runtime, url="https://example.com") + + mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"}) + + +@pytest.mark.asyncio +async def test_session_pool_tool_extracts_thread_id(): + """Thread ID is extracted from runtime.config when not in context.""" + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + original_tool = StructuredTool( + name="server_tool", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []}) + + mock_runtime = MagicMock() + mock_runtime.context = {} + mock_runtime.config = {"configurable": {"thread_id": "from-config"}} + + await wrapped.coroutine(runtime=mock_runtime, x=1) + + # Verify the session was created with the correct scope key. + pool = get_session_pool() + assert ("server", "from-config") in pool._entries + + +@pytest.mark.asyncio +async def test_session_pool_tool_default_scope(): + """When no thread_id is available, 'default' is used as scope key.""" + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + original_tool = StructuredTool( + name="server_tool", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []}) + + # No thread_id in runtime at all. + await wrapped.coroutine(runtime=None, x=1) + + pool = get_session_pool() + assert ("server", "default") in pool._entries + + +@pytest.mark.asyncio +async def test_session_pool_tool_get_config_fallback(): + """When runtime is None, get_config() provides thread_id as fallback.""" + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + + class Args(BaseModel): + x: int = Field(..., description="x") + + original_tool = StructuredTool( + name="server_tool", + description="test", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + fake_config = {"configurable": {"thread_id": "from-langgraph-config"}} + + with ( + patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm), + patch("deerflow.mcp.tools.get_config", return_value=fake_config), + ): + wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []}) + + # runtime=None — get_config() fallback should provide thread_id + await wrapped.coroutine(runtime=None, x=1) + + pool = get_session_pool() + assert ("server", "from-langgraph-config") in pool._entries + + +def test_session_pool_tool_sync_wrapper_path_is_safe(): + """Sync wrapper (tool.func) invocation doesn't crash on cross-loop access.""" + from langchain_core.tools import StructuredTool + from pydantic import BaseModel, Field + + from deerflow.mcp.tools import _make_session_pool_tool + from deerflow.tools.sync import make_sync_tool_wrapper + + class Args(BaseModel): + url: str = Field(..., description="url") + + original_tool = StructuredTool( + name="playwright_navigate", + description="Navigate browser", + args_schema=Args, + coroutine=AsyncMock(), + response_format="content_and_artifact", + ) + + mock_session = AsyncMock() + mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None)) + mock_cm = MagicMock() + mock_cm.__aenter__ = AsyncMock(return_value=mock_session) + mock_cm.__aexit__ = AsyncMock(return_value=False) + + connection = {"transport": "stdio", "command": "pw", "args": []} + + with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm): + wrapped = _make_session_pool_tool(original_tool, "playwright", connection) + # Attach the sync wrapper exactly as get_mcp_tools() does. + wrapped.func = make_sync_tool_wrapper(wrapped.coroutine, wrapped.name) + + # Call via the sync path (asyncio.run in a worker thread). + # runtime is not supplied so _extract_thread_id falls back to "default". + wrapped.func(url="https://example.com") + + mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})