Compare commits

..

2 Commits

Author SHA1 Message Date
Willem Jiang c1af6cc4fc Potential fix for pull request finding
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-26 10:56:05 +08:00
Willem Jiang 761a535d6b fix(checkpointer): use AsyncConnectionPool for postgres to prevent stale connection errors (#3223)
Replace AsyncPostgresSaver.from_conn_string() with an explicit
  AsyncConnectionPool that has check_connection enabled, so dead idle
  connections are detected and replaced on checkout instead of raising
  OperationalError.
2026-05-26 10:02:16 +08:00
4 changed files with 77 additions and 92 deletions
+4 -13
View File
@@ -1,4 +1,4 @@
"""Load MCP tools using langchain-mcp-adapters with stdio session pooling.""" """Load MCP tools using langchain-mcp-adapters with persistent sessions."""
from __future__ import annotations from __future__ import annotations
@@ -173,10 +173,8 @@ def _make_session_pool_tool(
async def get_mcp_tools() -> list[BaseTool]: async def get_mcp_tools() -> list[BaseTool]:
"""Get all tools from enabled MCP servers. """Get all tools from enabled MCP servers.
Tools using stdio transport are wrapped with persistent-session logic so Tools are wrapped with persistent-session logic so that consecutive
consecutive calls within the same thread reuse the same MCP session. calls within the same thread reuse the same MCP session.
HTTP/SSE tools are returned unwrapped to avoid cross-task TaskGroup
cleanup errors.
Returns: Returns:
List of LangChain tools from all enabled MCP servers. List of LangChain tools from all enabled MCP servers.
@@ -253,9 +251,6 @@ async def get_mcp_tools() -> list[BaseTool]:
logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers") logger.info(f"Successfully loaded {len(tools)} tool(s) from MCP servers")
# Wrap each tool with persistent-session logic. # Wrap each tool with persistent-session logic.
# Only pool stdio sessions. HTTP/SSE transports use anyio TaskGroups
# internally which cannot be closed from a different async task, so
# pooling them causes RuntimeError on cleanup (see #3203).
wrapped_tools: list[BaseTool] = [] wrapped_tools: list[BaseTool] = []
for tool in tools: for tool in tools:
tool_server: str | None = None tool_server: str | None = None
@@ -265,11 +260,7 @@ async def get_mcp_tools() -> list[BaseTool]:
break break
if tool_server is not None: if tool_server is not None:
transport = servers_config[tool_server].get("transport", "stdio") wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
if transport == "stdio":
wrapped_tools.append(_make_session_pool_tool(tool, tool_server, servers_config[tool_server], tool_interceptors))
else:
wrapped_tools.append(tool)
else: else:
wrapped_tools.append(tool) wrapped_tools.append(tool)
@@ -67,10 +67,22 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
except ImportError as exc: except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc raise ImportError(POSTGRES_INSTALL) from exc
try:
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not config.connection_string: if not config.connection_string:
raise ValueError(POSTGRES_CONN_REQUIRED) raise ValueError(POSTGRES_CONN_REQUIRED)
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver: pool = AsyncConnectionPool(
config.connection_string,
kwargs={"autocommit": True, "prepare_threshold": 0, "row_factory": dict_row},
check=AsyncConnectionPool.check_connection,
)
async with pool:
saver = AsyncPostgresSaver(conn=pool)
await saver.setup() await saver.setup()
yield saver yield saver
return return
@@ -111,10 +123,22 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
except ImportError as exc: except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc raise ImportError(POSTGRES_INSTALL) from exc
try:
from psycopg.rows import dict_row
from psycopg_pool import AsyncConnectionPool
except ImportError as exc:
raise ImportError(POSTGRES_INSTALL) from exc
if not db_config.postgres_url: if not db_config.postgres_url:
raise ValueError("database.postgres_url is required for the postgres backend") raise ValueError("database.postgres_url is required for the postgres backend")
async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver: pool = AsyncConnectionPool(
db_config.postgres_url,
kwargs={"autocommit": True, "prepare_threshold": 0, "row_factory": dict_row},
check=AsyncConnectionPool.check_connection,
)
async with pool:
saver = AsyncPostgresSaver(conn=pool)
await saver.setup() await saver.setup()
yield saver yield saver
return return
+47
View File
@@ -326,6 +326,53 @@ class TestAsyncCheckpointer:
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db") mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db")
mock_saver.setup.assert_awaited_once() mock_saver.setup.assert_awaited_once()
@pytest.mark.anyio
async def test_postgres_uses_connection_pool(self):
    """The postgres backend must build its saver on top of an AsyncConnectionPool.

    The pool class, ``dict_row`` and ``AsyncPostgresSaver`` are all replaced
    with mocks injected through ``sys.modules``, so no database driver is
    required to run this test.
    """
    from deerflow.runtime.checkpointer.async_provider import make_checkpointer

    conn_string = "postgresql://localhost/db"

    app_config = MagicMock()
    app_config.checkpointer = CheckpointerConfig(type="postgres", connection_string=conn_string)

    # Saver class stand-in: calling it hands back one known AsyncMock instance.
    mock_saver = AsyncMock()
    saver_factory = MagicMock(return_value=mock_saver)

    # Pool instance stand-in: entering it as an async context manager yields
    # the instance itself, and exiting does not suppress exceptions.
    pool = AsyncMock()
    pool.__aenter__.return_value = pool
    pool.__aexit__.return_value = False
    pool_factory = MagicMock(return_value=pool)
    pool_factory.check_connection = AsyncMock()

    # Fake modules picked up by the provider's function-level imports.
    fake_modules = {
        "langgraph.checkpoint.postgres.aio": MagicMock(AsyncPostgresSaver=saver_factory),
        "psycopg.rows": MagicMock(dict_row=MagicMock()),
        "psycopg_pool": MagicMock(AsyncConnectionPool=pool_factory),
    }

    with (
        patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=app_config),
        patch.dict(sys.modules, fake_modules),
    ):
        async with make_checkpointer() as saver:
            assert saver is mock_saver

    # The pool must be built exactly once, from the configured connection
    # string, with the pool's own check_connection hook as the ``check`` callback.
    pool_factory.assert_called_once()
    pool_call = pool_factory.call_args
    assert pool_call.args[0] == conn_string
    assert pool_call.kwargs["check"] is pool_factory.check_connection

    # The saver must be constructed directly on the pool (not through
    # from_conn_string) and have its setup awaited.
    saver_factory.assert_called_once_with(conn=pool)
    mock_saver.setup.assert_awaited_once()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# app_config.py integration # app_config.py integration
-77
View File
@@ -407,80 +407,3 @@ def test_session_pool_tool_sync_wrapper_path_is_safe():
wrapped.func(url="https://example.com") wrapped.func(url="https://example.com")
mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"}) mock_session.call_tool.assert_called_once_with("navigate", {"url": "https://example.com"})
# ---------------------------------------------------------------------------
# get_mcp_tools: HTTP transport should NOT be pooled
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_http_transport_tools_not_pooled():
    """Only stdio tools get the session-pool wrapper; HTTP/SSE tools pass through as-is."""
    from langchain_core.tools import StructuredTool
    from pydantic import BaseModel, Field

    from deerflow.mcp.tools import get_mcp_tools

    class Args(BaseModel):
        query: str = Field(..., description="query")

    def build_tool(name, description):
        # Both fixtures share everything except their name and description.
        return StructuredTool(
            name=name,
            description=description,
            args_schema=Args,
            coroutine=AsyncMock(),
            response_format="content_and_artifact",
        )

    http_tool = build_tool("myserver_search", "Search tool")
    stdio_tool = build_tool("playwright_navigate", "Navigate browser")

    # Async context manager stand-in returned by the patched create_session.
    mock_session = AsyncMock()
    mock_cm = MagicMock()
    mock_cm.__aenter__ = AsyncMock(return_value=mock_session)
    mock_cm.__aexit__ = AsyncMock(return_value=False)

    enabled_servers = {
        "myserver": MagicMock(type="http", url="http://localhost:8000/mcp", headers=None, command=None, args=[], env=None),
        "playwright": MagicMock(type="stdio", command="npx", args=["-y", "@anthropic/mcp-server-playwright"], env=None, url=None, headers=None),
    }
    extensions_config = MagicMock()
    extensions_config.get_enabled_mcp_servers.return_value = enabled_servers
    extensions_config.model_extra = {}

    servers_config = {
        "myserver": {"transport": "http", "url": "http://localhost:8000/mcp"},
        "playwright": {"transport": "stdio", "command": "npx", "args": ["-y", "@anthropic/mcp-server-playwright"]},
    }

    with (
        patch("deerflow.mcp.tools.ExtensionsConfig.from_file", return_value=extensions_config),
        patch("deerflow.mcp.tools.build_servers_config", return_value=servers_config),
        patch("deerflow.mcp.tools.get_initial_oauth_headers", return_value={}),
        patch("deerflow.mcp.tools.build_oauth_tool_interceptor", return_value=None),
        patch("langchain_mcp_adapters.client.MultiServerMCPClient") as MockClient,
        patch("langchain_mcp_adapters.sessions.create_session", return_value=mock_cm),
    ):
        MockClient.return_value.get_tools = AsyncMock(return_value=[http_tool, stdio_tool])
        tools = await get_mcp_tools()

    # Discovery is lazy: no pooled session exists until a wrapped tool is invoked.
    pool = get_session_pool()
    assert list(pool._entries.keys()) == []

    def sole_tool_named(name):
        # Exactly one returned tool is expected to carry the given name.
        matches = [candidate for candidate in tools if candidate.name == name]
        assert len(matches) == 1
        return matches[0]

    # The HTTP tool comes back untouched, so its coroutine is the original one.
    assert sole_tool_named("myserver_search").coroutine is http_tool.coroutine

    # The stdio tool comes back wrapped, so its coroutine has been replaced.
    assert sole_tool_named("playwright_navigate").coroutine is not stdio_tool.coroutine