mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-20 07:01:03 +00:00
1b88c38d80
MCP tools loaded via langchain-mcp-adapters created a new session on every call, causing stateful servers like Playwright to lose browser state (pages, forms) between consecutive tool invocations within the same thread. Add MCPSessionPool that maintains persistent sessions scoped by (server_name, thread_id). Tool calls within the same thread now reuse the same MCP session, preserving server-side state. Sessions are evicted in LRU order (max 256) and cleaned up on cache invalidation. Fixes #3054
331 lines
10 KiB
Python
331 lines
10 KiB
Python
"""Tests for the MCP persistent-session pool."""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from deerflow.mcp.session_pool import MCPSessionPool, get_session_pool, reset_session_pool
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _reset_pool():
    """Isolate every test from the process-wide MCP session pool.

    ``reset_session_pool()`` clears the singleton once before the test body
    runs and once after it, so a pool populated by one test is never
    observed by the next one.
    """
    reset_session_pool()
    try:
        yield
    finally:
        # Teardown half of the fixture: clear the singleton again.
        reset_session_pool()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# MCPSessionPool unit tests
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_session_creates_new():
    """The first request for a (server, scope) key opens and initialises a session."""
    pool = MCPSessionPool()

    opened_session = AsyncMock()
    fake_cm = MagicMock()
    fake_cm.__aenter__ = AsyncMock(return_value=opened_session)
    fake_cm.__aexit__ = AsyncMock(return_value=False)
    stdio_conn = {"transport": "stdio", "command": "x", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=fake_cm):
        got = await pool.get_session("server", "thread-1", stdio_conn)

    # The pool hands back exactly the object yielded by the session context
    # manager, and calls initialize() on it exactly once.
    assert got is opened_session
    opened_session.initialize.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_get_session_reuses_existing():
    """Asking twice for the same (server, scope) key yields the cached session."""
    pool = MCPSessionPool()

    cached_session = AsyncMock()
    fake_cm = MagicMock()
    fake_cm.__aenter__ = AsyncMock(return_value=cached_session)
    fake_cm.__aexit__ = AsyncMock(return_value=False)

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=fake_cm):
        first, second = [
            await pool.get_session("server", "thread-1", {"transport": "stdio", "command": "x", "args": []})
            for _ in range(2)
        ]

    assert first is second
    # Two lookups, but the underlying session context manager was entered once.
    assert fake_cm.__aenter__.await_count == 1
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_different_scope_creates_different_session():
    """Different scope keys (thread IDs) on the same server get different sessions."""
    pool = MCPSessionPool()

    sessions = [AsyncMock(), AsyncMock()]
    # Each context manager produced by the patched factory hands out the next
    # prepared session, so the test can tell which lookup produced which one.
    session_iter = iter(sessions)

    class _SequentialSessionCM:
        """Session context manager yielding the next prepared mock session."""

        async def __aenter__(self):
            return next(session_iter)

        async def __aexit__(self, *args):
            return False

    stdio_conn = {"transport": "stdio", "command": "x", "args": []}

    with patch(
        "langchain_mcp_adapters.sessions.create_session",
        side_effect=lambda *a, **kw: _SequentialSessionCM(),
    ):
        s1 = await pool.get_session("server", "thread-1", stdio_conn)
        s2 = await pool.get_session("server", "thread-2", stdio_conn)

    assert s1 is not s2
    assert s1 is sessions[0]
    assert s2 is sessions[1]
    # Each freshly opened session is initialised exactly once.
    sessions[0].initialize.assert_awaited_once()
    sessions[1].initialize.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_lru_eviction():
    """Oldest entries are evicted (closed and removed) when the pool is full."""
    pool = MCPSessionPool()
    # Shrink the capacity on this instance so eviction triggers after two sessions.
    pool.MAX_SESSIONS = 2

    class TrackingCM:
        """Session context manager that records whether it has been exited."""

        def __init__(self):
            self.closed = False

        async def __aenter__(self):
            return AsyncMock()

        async def __aexit__(self, *args):
            self.closed = True
            return False

    cms: list[TrackingCM] = []

    def make_cm(*a, **kw):
        cm = TrackingCM()
        cms.append(cm)
        return cm

    stdio_conn = {"transport": "stdio", "command": "x", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
        await pool.get_session("s", "t1", stdio_conn)
        await pool.get_session("s", "t2", stdio_conn)
        # Pool is full (2). Adding t3 should evict t1.
        await pool.get_session("s", "t3", stdio_conn)

    # One context manager per distinct scope.
    assert len(cms) == 3
    assert cms[0].closed is True
    assert cms[1].closed is False
    assert cms[2].closed is False

    # The evicted key is gone from the pool; the two newest remain.
    assert ("s", "t1") not in pool._entries
    assert ("s", "t2") in pool._entries
    assert ("s", "t3") in pool._entries
    assert len(pool._entries) == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_close_scope():
    """close_scope shuts down and removes sessions for a specific scope key only."""
    pool = MCPSessionPool()

    class TrackingCM:
        """Session context manager that records whether it has been exited."""

        def __init__(self):
            self.closed = False

        async def __aenter__(self):
            return AsyncMock()

        async def __aexit__(self, *args):
            self.closed = True
            return False

    cms: list[TrackingCM] = []

    def make_cm(*a, **kw):
        cm = TrackingCM()
        cms.append(cm)
        return cm

    stdio_conn = {"transport": "stdio", "command": "x", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
        await pool.get_session("s", "t1", stdio_conn)
        await pool.get_session("s", "t2", stdio_conn)

    await pool.close_scope("t1")

    assert len(cms) == 2
    assert cms[0].closed is True
    assert cms[1].closed is False

    # The closed scope's entry is dropped; t2's session still exists.
    assert ("s", "t1") not in pool._entries
    assert ("s", "t2") in pool._entries
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_close_all():
    """close_all shuts down every session and empties the pool."""
    pool = MCPSessionPool()

    class TrackingCM:
        """Session context manager that records whether it has been exited."""

        def __init__(self):
            self.closed = False

        async def __aenter__(self):
            return AsyncMock()

        async def __aexit__(self, *args):
            self.closed = True
            return False

    cms: list[TrackingCM] = []

    def make_cm(*a, **kw):
        cm = TrackingCM()
        cms.append(cm)
        return cm

    stdio_conn = {"transport": "stdio", "command": "x", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", side_effect=make_cm):
        await pool.get_session("s1", "t1", stdio_conn)
        await pool.get_session("s2", "t2", stdio_conn)

    await pool.close_all()

    # Guard against a vacuous pass: all() over an empty list is True.
    assert len(cms) == 2
    assert all(cm.closed for cm in cms)
    assert len(pool._entries) == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Singleton helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_get_session_pool_singleton():
    """Repeated calls to get_session_pool hand back one shared pool object."""
    first, second = get_session_pool(), get_session_pool()
    assert first is second
|
|
|
|
|
|
def test_reset_session_pool():
    """After reset_session_pool, get_session_pool builds a brand-new instance."""
    before_reset = get_session_pool()
    reset_session_pool()
    after_reset = get_session_pool()
    assert after_reset is not before_reset
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration: _make_session_pool_tool uses the pool
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_pool_tool_wrapping():
    """The wrapper tool delegates to a pool-managed session scoped by thread_id."""
    # Build a dummy StructuredTool (as returned by langchain-mcp-adapters).
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class Args(BaseModel):
        url: str = Field(..., description="url")

    # Keep our own handle on the original coroutine so we can verify the
    # wrapper never invokes it.
    original_coroutine = AsyncMock()
    original_tool = StructuredTool(
        name="playwright_navigate",
        description="Navigate browser",
        args_schema=Args,
        coroutine=original_coroutine,
        response_format="content_and_artifact",
    )

    mock_session = AsyncMock()
    mock_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
    mock_cm = MagicMock()
    mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
    mock_cm.__aexit__ = AsyncMock(return_value=False)

    connection = {"transport": "stdio", "command": "pw", "args": []}

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm):
        wrapped = _make_session_pool_tool(original_tool, "playwright", connection)

        # Simulate a tool call with a runtime context containing thread_id.
        mock_runtime = MagicMock()
        mock_runtime.context = {"thread_id": "thread-42"}
        mock_runtime.config = {}

        await wrapped.coroutine(runtime=mock_runtime, url="https://example.com")

    # The tool name "playwright_navigate" reaches the session as "navigate".
    mock_session.call_tool.assert_awaited_once_with("navigate", {"url": "https://example.com"})
    # The original tool's coroutine is bypassed in favour of the pooled session.
    original_coroutine.assert_not_awaited()
    # The session is pooled under (server_name, thread_id).
    assert ("playwright", "thread-42") in get_session_pool()._entries
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_pool_tool_extracts_thread_id():
    """With an empty runtime.context, the thread ID is taken from runtime.config."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class ToolArgs(BaseModel):
        x: int = Field(..., description="x")

    pooled_session = AsyncMock()
    pooled_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=pooled_session)
    session_cm.__aexit__ = AsyncMock(return_value=False)

    base_tool = StructuredTool(
        name="server_tool",
        description="test",
        args_schema=ToolArgs,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )

    # thread_id is absent from context and only present under config["configurable"].
    runtime = MagicMock(context={}, config={"configurable": {"thread_id": "from-config"}})

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=session_cm):
        wrapped = _make_session_pool_tool(base_tool, "server", {"transport": "stdio", "command": "x", "args": []})
        await wrapped.coroutine(runtime=runtime, x=1)

    # The pool entry is keyed by the thread ID found in the config.
    assert ("server", "from-config") in get_session_pool()._entries
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_session_pool_tool_default_scope():
    """Without any runtime, the wrapper falls back to the 'default' scope key."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import _make_session_pool_tool

    class ToolArgs(BaseModel):
        x: int = Field(..., description="x")

    pooled_session = AsyncMock()
    pooled_session.call_tool = AsyncMock(return_value=MagicMock(content=[], isError=False, structuredContent=None))
    session_cm = MagicMock()
    session_cm.__aenter__ = AsyncMock(return_value=pooled_session)
    session_cm.__aexit__ = AsyncMock(return_value=False)

    base_tool = StructuredTool(
        name="server_tool",
        description="test",
        args_schema=ToolArgs,
        coroutine=AsyncMock(),
        response_format="content_and_artifact",
    )

    with patch("langchain_mcp_adapters.sessions.create_session", return_value=session_cm):
        wrapped = _make_session_pool_tool(base_tool, "server", {"transport": "stdio", "command": "x", "args": []})
        # runtime=None: there is no context or config to read a thread_id from.
        await wrapped.coroutine(runtime=None, x=1)

    assert ("server", "default") in get_session_pool()._entries
|