mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 07:01:03 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 1b88c38d80 |
@@ -134,9 +134,36 @@ def reset_mcp_tools_cache() -> None:
|
||||
"""Reset the MCP tools cache.
|
||||
|
||||
This is useful for testing or when you want to reload MCP tools.
|
||||
Also closes all persistent MCP sessions so they are recreated on
|
||||
the next tool load.
|
||||
"""
|
||||
global _mcp_tools_cache, _cache_initialized, _config_mtime
|
||||
_mcp_tools_cache = None
|
||||
_cache_initialized = False
|
||||
_config_mtime = None
|
||||
|
||||
# Close persistent sessions – they will be recreated by the next
|
||||
# get_mcp_tools() call with the (possibly updated) connection config.
|
||||
try:
|
||||
from deerflow.mcp.session_pool import get_session_pool
|
||||
|
||||
pool = get_session_pool()
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
import concurrent.futures
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(asyncio.run, pool.close_all())
|
||||
future.result()
|
||||
else:
|
||||
loop.run_until_complete(pool.close_all())
|
||||
except RuntimeError:
|
||||
asyncio.run(pool.close_all())
|
||||
except Exception:
|
||||
logger.debug("Could not close MCP session pool on cache reset", exc_info=True)
|
||||
|
||||
from deerflow.mcp.session_pool import reset_session_pool
|
||||
|
||||
reset_session_pool()
|
||||
logger.info("MCP tools cache reset")
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
"""Persistent MCP session pool for stateful tool calls.
|
||||
|
||||
When MCP tools are loaded via langchain-mcp-adapters with ``session=None``,
|
||||
each tool call creates a new MCP session. For stateful servers like Playwright,
|
||||
this means browser state (opened pages, filled forms) is lost between calls.
|
||||
|
||||
This module provides a session pool that maintains persistent MCP sessions,
|
||||
scoped by ``(server_name, scope_key)`` — typically scope_key is the thread_id —
|
||||
so that consecutive tool calls share the same session and server-side state.
|
||||
Sessions are evicted in LRU order when the pool reaches capacity.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from mcp import ClientSession
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MCPSessionPool:
|
||||
"""Manages persistent MCP sessions scoped by ``(server_name, scope_key)``."""
|
||||
|
||||
MAX_SESSIONS = 256
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._entries: OrderedDict[
|
||||
tuple[str, str],
|
||||
tuple[ClientSession, asyncio.AbstractEventLoop],
|
||||
] = OrderedDict()
|
||||
self._context_managers: dict[tuple[str, str], Any] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def get_session(
|
||||
self,
|
||||
server_name: str,
|
||||
scope_key: str,
|
||||
connection: dict[str, Any],
|
||||
) -> ClientSession:
|
||||
"""Get or create a persistent MCP session.
|
||||
|
||||
If an existing session was created in a different event loop (e.g.
|
||||
the sync-wrapper path), it is closed and replaced with a fresh one
|
||||
in the current loop.
|
||||
|
||||
Args:
|
||||
server_name: MCP server name.
|
||||
scope_key: Isolation key (typically thread_id).
|
||||
connection: Connection configuration for ``create_session``.
|
||||
|
||||
Returns:
|
||||
An initialized ``ClientSession``.
|
||||
"""
|
||||
key = (server_name, scope_key)
|
||||
current_loop = asyncio.get_running_loop()
|
||||
|
||||
async with self._lock:
|
||||
if key in self._entries:
|
||||
session, loop = self._entries[key]
|
||||
if loop is current_loop:
|
||||
self._entries.move_to_end(key)
|
||||
return session
|
||||
# Session belongs to a different event loop – close it.
|
||||
await self._close_session(key)
|
||||
|
||||
# Evict oldest entries when at capacity.
|
||||
while len(self._entries) >= self.MAX_SESSIONS:
|
||||
oldest_key = next(iter(self._entries))
|
||||
await self._close_session(oldest_key)
|
||||
|
||||
from langchain_mcp_adapters.sessions import create_session
|
||||
|
||||
cm = create_session(connection)
|
||||
session = await cm.__aenter__()
|
||||
await session.initialize()
|
||||
self._entries[key] = (session, current_loop)
|
||||
self._context_managers[key] = cm
|
||||
logger.info("Created persistent MCP session for %s/%s", server_name, scope_key)
|
||||
return session
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Cleanup helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _close_session(self, key: tuple[str, str]) -> None:
|
||||
cm = self._context_managers.pop(key, None)
|
||||
self._entries.pop(key, None)
|
||||
if cm is not None:
|
||||
try:
|
||||
await cm.__aexit__(None, None, None)
|
||||
except Exception:
|
||||
logger.warning("Error closing MCP session %s", key, exc_info=True)
|
||||
|
||||
async def close_scope(self, scope_key: str) -> None:
|
||||
"""Close all sessions for a given scope (e.g. thread_id)."""
|
||||
async with self._lock:
|
||||
keys_to_close = [k for k in self._entries if k[1] == scope_key]
|
||||
for key in keys_to_close:
|
||||
await self._close_session(key)
|
||||
|
||||
async def close_server(self, server_name: str) -> None:
|
||||
"""Close all sessions for a given server."""
|
||||
async with self._lock:
|
||||
keys_to_close = [k for k in self._entries if k[0] == server_name]
|
||||
for key in keys_to_close:
|
||||
await self._close_session(key)
|
||||
|
||||
async def close_all(self) -> None:
|
||||
"""Close every managed session."""
|
||||
async with self._lock:
|
||||
for key in list(self._context_managers.keys()):
|
||||
await self._close_session(key)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Module-level singleton
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_pool: MCPSessionPool | None = None
|
||||
_pool_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_session_pool() -> MCPSessionPool:
|
||||
"""Return the global session-pool singleton."""
|
||||
global _pool
|
||||
if _pool is None:
|
||||
with _pool_lock:
|
||||
if _pool is None:
|
||||
_pool = MCPSessionPool()
|
||||
return _pool
|
||||
|
||||
|
||||
def reset_session_pool() -> None:
|
||||
"""Reset the singleton (for tests)."""
|
||||
global _pool
|
||||
_pool = None
|
||||
@@ -1,21 +1,83 @@
|
||||
"""Load MCP tools using langchain-mcp-adapters."""
|
||||
"""Load MCP tools using langchain-mcp-adapters with persistent sessions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from langchain_core.tools import BaseTool
|
||||
from langchain_core.tools import BaseTool, InjectedToolArg, StructuredTool
|
||||
|
||||
from deerflow.config.extensions_config import ExtensionsConfig
|
||||
from deerflow.mcp.client import build_servers_config
|
||||
from deerflow.mcp.oauth import build_oauth_tool_interceptor, get_initial_oauth_headers
|
||||
from deerflow.mcp.session_pool import get_session_pool
|
||||
from deerflow.reflection import resolve_variable
|
||||
from deerflow.tools.sync import make_sync_tool_wrapper
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_thread_id(runtime: Any) -> str:
|
||||
"""Extract thread_id from the injected tool runtime."""
|
||||
if runtime is not None:
|
||||
ctx = getattr(runtime, "context", None) or {}
|
||||
tid = ctx.get("thread_id")
|
||||
if tid is not None:
|
||||
return str(tid)
|
||||
config = getattr(runtime, "config", None) or {}
|
||||
tid = config.get("configurable", {}).get("thread_id")
|
||||
if tid is not None:
|
||||
return str(tid)
|
||||
return "default"
|
||||
|
||||
|
||||
def _make_session_pool_tool(
|
||||
tool: BaseTool,
|
||||
server_name: str,
|
||||
connection: dict[str, Any],
|
||||
) -> BaseTool:
|
||||
"""Wrap an MCP tool so it reuses a persistent session from the pool.
|
||||
|
||||
Replaces the per-call session creation with pool-managed sessions scoped
|
||||
by ``(server_name, thread_id)``. This ensures stateful MCP servers (e.g.
|
||||
Playwright) keep their state across tool calls within the same thread.
|
||||
"""
|
||||
# Strip the server-name prefix to recover the original MCP tool name.
|
||||
original_name = tool.name
|
||||
prefix = f"{server_name}_"
|
||||
if original_name.startswith(prefix):
|
||||
original_name = original_name[len(prefix) :]
|
||||
|
||||
pool = get_session_pool()
|
||||
|
||||
async def call_with_persistent_session(
|
||||
runtime: Annotated[object | None, InjectedToolArg()] = None,
|
||||
**arguments: dict[str, Any],
|
||||
) -> Any:
|
||||
thread_id = _extract_thread_id(runtime)
|
||||
session = await pool.get_session(server_name, thread_id, connection)
|
||||
call_tool_result = await session.call_tool(original_name, arguments)
|
||||
|
||||
from langchain_mcp_adapters.tools import _convert_call_tool_result
|
||||
|
||||
return _convert_call_tool_result(call_tool_result)
|
||||
|
||||
return StructuredTool(
|
||||
name=tool.name,
|
||||
description=tool.description,
|
||||
args_schema=tool.args_schema,
|
||||
coroutine=call_with_persistent_session,
|
||||
response_format="content_and_artifact",
|
||||
metadata=tool.metadata,
|
||||
)
|
||||
|
||||
|
||||
async def get_mcp_tools() -> list[BaseTool]:
|
||||
"""Get all tools from enabled MCP servers.
|
||||
|
||||
Tools are wrapped with persistent-session logic so that consecutive
|
||||
calls within the same thread reuse the same MCP session.
|
||||
|
||||
Returns:
|
||||
List of LangChain tools from all enabled MCP servers.
|
||||
"""
|
||||
@@ -50,7 +112,7 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
existing_headers["Authorization"] = auth_header
|
||||
servers_config[server_name]["headers"] = existing_headers
|
||||
|
||||
tool_interceptors = []
|
||||
tool_interceptors: list[Any] = []
|
||||
oauth_interceptor = build_oauth_tool_interceptor(extensions_config)
|
||||
if oauth_interceptor is not None:
|
||||
tool_interceptors.append(oauth_interceptor)
|
||||
@@ -74,20 +136,42 @@ async def get_mcp_tools() -> list[BaseTool]:
|
||||
elif interceptor is not None:
|
||||
logger.warning(f"Builder {interceptor_path} returned non-callable {type(interceptor).__name__}; skipping")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load MCP interceptor {interceptor_path}: {e}", exc_info=True)
|
||||
logger.warning(
|
||||
f"Failed to load MCP interceptor {interceptor_path}: {e}",
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
client = MultiServerMCPClient(servers_config, tool_interceptors=tool_interceptors, tool_name_prefix=True)
|
||||
client = MultiServerMCPClient(
|
||||
servers_config,
|
||||
tool_interceptors=tool_interceptors,
|
||||
tool_name_prefix=True,
|
||||
)
|
||||
|
||||
# Get all tools from all servers
|
||||
# Get all tools from all servers (discovers tool definitions via
|
||||
# temporary sessions – the persistent-session wrapping is applied below).
|
||||
tools = await client.get_tools()
|
||||
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
|
||||
|
||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||
# Wrap each tool with persistent-session logic.
|
||||
wrapped_tools: list[BaseTool] = []
|
||||
for tool in tools:
|
||||
tool_server: str | None = None
|
||||
for name in servers_config:
|
||||
if tool.name.startswith(f"{name}_"):
|
||||
tool_server = name
|
||||
break
|
||||
|
||||
if tool_server is not None:
|
||||
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server]))
|
||||
else:
|
||||
wrapped_tools.append(tool)
|
||||
|
||||
# Patch tools to support sync invocation, as deerflow client streams synchronously
|
||||
for tool in wrapped_tools:
|
||||
if getattr(tool, "func", None) is None and getattr(tool, "coroutine", None) is not None:
|
||||
tool.func = make_sync_tool_wrapper(tool.coroutine, tool.name)
|
||||
|
||||
return tools
|
||||
return wrapped_tools
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load MCP tools: {e}", exc_info=True)
|
||||
|
||||
@@ -0,0 +1,330 @@
|
||||
"""Tests for the MCP persistent-session pool."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_pool():
|
||||
reset_session_pool()
|
||||
yield
|
||||
reset_session_pool()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MCPSessionPool unit tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_creates_new():
|
||||
"""First call for a key creates a new session."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
session = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert session is mock_session
|
||||
mock_session.initialize.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_session_reuses_existing():
|
||||
"""Second call for the same key returns the cached session."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
s2 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert s1 is s2
|
||||
# Only one session should have been created.
|
||||
assert mock_cm.__aenter__.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_different_scope_creates_different_session():
|
||||
"""Different scope keys get different sessions."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
sessions = [AsyncMock(), AsyncMock()]
|
||||
idx = 0
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.enter_count = 0
|
||||
|
||||
async def __aenter__(self):
|
||||
nonlocal idx
|
||||
s = sessions[idx]
|
||||
idx += 1
|
||||
self.enter_count += 1
|
||||
return s
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return False
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=lambda *a, **kw: CmFactory()):
|
||||
s1 = await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
|
||||
s2 = await pool.get_session("server", "thread-2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert s1 is not s2
|
||||
assert s1 is sessions[0]
|
||||
assert s2 is sessions[1]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_lru_eviction():
|
||||
"""Oldest entries are evicted when the pool is full."""
|
||||
pool = MCPSessionPool()
|
||||
pool.MAX_SESSIONS = 2
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return AsyncMock()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
cms: list[CmFactory] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = CmFactory()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
# Pool is full (2). Adding t3 should evict t1.
|
||||
await pool.get_session("s", "t3", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
assert cms[0].closed is True
|
||||
assert cms[1].closed is False
|
||||
assert cms[2].closed is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_scope():
|
||||
"""close_scope shuts down sessions for a specific scope key."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return AsyncMock()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
cms: list[CmFactory] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = CmFactory()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await pool.get_session("s", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await pool.get_session("s", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
await pool.close_scope("t1")
|
||||
|
||||
assert cms[0].closed is True
|
||||
assert cms[1].closed is False
|
||||
|
||||
# t2 session still exists.
|
||||
assert ("s", "t2") in pool._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_close_all():
|
||||
"""close_all shuts down every session."""
|
||||
pool = MCPSessionPool()
|
||||
|
||||
class CmFactory:
|
||||
def __init__(self):
|
||||
self.closed = False
|
||||
|
||||
async def __aenter__(self):
|
||||
return AsyncMock()
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
self.closed = True
|
||||
return False
|
||||
|
||||
cms: list[CmFactory] = []
|
||||
|
||||
def make_cm(*a, **kw):
|
||||
cm = CmFactory()
|
||||
cms.append(cm)
|
||||
return cm
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
|
||||
await pool.get_session("s1", "t1", {"transport": "stdio", "command": "x", "args": []})
|
||||
await pool.get_session("s2", "t2", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
await pool.close_all()
|
||||
|
||||
assert all(cm.closed for cm in cms)
|
||||
assert len(pool._entries) == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Singleton helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_session_pool_singleton():
|
||||
"""get_session_pool returns the same instance."""
|
||||
p1 = get_session_pool()
|
||||
p2 = get_session_pool()
|
||||
assert p1 is p2
|
||||
|
||||
|
||||
def test_reset_session_pool():
|
||||
"""reset_session_pool clears the singleton."""
|
||||
p1 = get_session_pool()
|
||||
reset_session_pool()
|
||||
p2 = get_session_pool()
|
||||
assert p1 is not p2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: _make_session_pool_tool uses the pool
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_wrapping():
|
||||
"""The wrapper tool delegates to a pool-managed session."""
|
||||
# Build a dummy StructuredTool (as returned by langchain-mcp-adapters).
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
url: str = Field(..., description="url")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="playwright_navigate",
|
||||
description="Navigate browser",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
connection = {"transport": "stdio", "command": "pw", "args": []}
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "playwright", connection)
|
||||
|
||||
# Simulate a tool call with a runtime context containing thread_id.
|
||||
mock_runtime = MagicMock()
|
||||
mock_runtime.context = {"thread_id": "thread-42"}
|
||||
mock_runtime.config = {}
|
||||
|
||||
await wrapped.coroutine(runtime=mock_runtime, url="https://example.com")
|
||||
|
||||
mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_extracts_thread_id():
|
||||
"""Thread ID is extracted from runtime.config when not in context."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="server_tool",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
mock_runtime = MagicMock()
|
||||
mock_runtime.context = {}
|
||||
mock_runtime.config = {"configurable": {"thread_id": "from-config"}}
|
||||
|
||||
await wrapped.coroutine(runtime=mock_runtime, x=1)
|
||||
|
||||
# Verify the session was created with the correct scope key.
|
||||
pool = get_session_pool()
|
||||
assert ("server", "from-config") in pool._entries
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_session_pool_tool_default_scope():
|
||||
"""When no thread_id is available, 'default' is used as scope key."""
|
||||
from langchain_core.tools import StructuredTool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from deerflow.mcp.tools import _make_session_pool_tool
|
||||
|
||||
class Args(BaseModel):
|
||||
x: int = Field(..., description="x")
|
||||
|
||||
original_tool = StructuredTool(
|
||||
name="server_tool",
|
||||
description="test",
|
||||
args_schema=Args,
|
||||
coroutine=AsyncMock(),
|
||||
response_format="content_and_artifact",
|
||||
)
|
||||
|
||||
mock_session = AsyncMock()
|
||||
mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
|
||||
mock_cm = MagicMock()
|
||||
mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
|
||||
mock_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
|
||||
with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
|
||||
wrapped = _make_session_pool_tool(original_tool, "server", {"transport": "stdio", "command": "x", "args": []})
|
||||
|
||||
# No thread_id in runtime at all.
|
||||
await wrapped.coroutine(runtime=None, x=1)
|
||||
|
||||
pool = get_session_pool()
|
||||
assert ("server", "default") in pool._entries
|
||||
Reference in New Issue
Block a user