mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
fix(checkpointer): use AsyncConnectionPool for postgres to prevent stale connection errors (#3223) (#3226)
* fix(checkpointer): use AsyncConnectionPool for postgres to prevent stale connection errors (#3223) Replace AsyncPostgresSaver.from_conn_string() with an explicit AsyncConnectionPool that has check_connection enabled, so dead idle connections are detected and replaced on checkout instead of raising OperationalError. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * Fixed the unit test and lint errors * fix(checkpointer): add TCP keepalive to postgres connection pool (#3254) Enable TCP keepalive probes on the AsyncConnectionPool to prevent idle postgres connections from being dropped by the server or network middleware. Combined with the existing check_connection callback, this provides defense-in-depth against stale connection errors. Fixes #3254 * Changed the code as suggested in review --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -47,6 +47,41 @@ def _prepare_database_sqlite_checkpointer_path(db_config) -> str:
|
|||||||
return conn_str
|
return conn_str
|
||||||
|
|
||||||
|
|
||||||
|
def _build_postgres_pool(conn_string: str):
    """Build an unopened AsyncConnectionPool with TCP keepalive and connection checking.

    The pool is created with ``open=False`` because opening an async pool from
    its constructor is deprecated in psycopg_pool. Callers must open it
    themselves, typically with ``async with pool:``, which also closes the
    pool on exit.

    Args:
        conn_string: Connection string (or URL) handed to the pool unchanged.

    Returns:
        A not-yet-opened ``AsyncConnectionPool`` whose connections are checked
        with ``AsyncConnectionPool.check_connection`` on checkout.
    """
    # Function-scope imports: presumably keeps this module importable when the
    # optional postgres dependencies are not installed.
    from psycopg.rows import dict_row
    from psycopg_pool import AsyncConnectionPool

    return AsyncConnectionPool(
        conn_string,
        # Forwarded to every connection the pool creates.
        kwargs={
            "autocommit": True,
            "prepare_threshold": 0,
            "row_factory": dict_row,
            # libpq TCP keepalive settings: probe idle sockets so they are not
            # silently dropped by the server or by network middleware.
            "keepalives": 1,
            "keepalives_idle": 60,
            "keepalives_interval": 10,
            "keepalives_count": 6,
        },
        # Verify a connection on checkout so a dead one is replaced instead of
        # being handed to the caller.
        check=AsyncConnectionPool.check_connection,
        # Do not open in the constructor (deprecated for async pools); the
        # caller opens the pool by entering it as an async context manager.
        open=False,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_postgres_imports():
    """Resolve the optional postgres checkpointer dependencies.

    Returns:
        The tuple ``(AsyncPostgresSaver, AsyncConnectionPool)``.

    Raises:
        ImportError: With the ``POSTGRES_INSTALL`` message, chained to the
            underlying error, when either package cannot be imported.
    """
    try:
        from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver as saver_cls
        from psycopg_pool import AsyncConnectionPool as pool_cls
    except ImportError as import_error:
        # Both dependencies share the same installation hint.
        raise ImportError(POSTGRES_INSTALL) from import_error

    return saver_cls, pool_cls
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Async factory
|
# Async factory
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -74,15 +109,13 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
|
|||||||
return
|
return
|
||||||
|
|
||||||
if config.type == "postgres":
|
if config.type == "postgres":
|
||||||
try:
|
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(POSTGRES_INSTALL) from exc
|
|
||||||
|
|
||||||
if not config.connection_string:
|
if not config.connection_string:
|
||||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||||
|
|
||||||
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
|
AsyncPostgresSaver, _ = _ensure_postgres_imports()
|
||||||
|
pool = _build_postgres_pool(config.connection_string)
|
||||||
|
async with pool:
|
||||||
|
saver = AsyncPostgresSaver(conn=pool)
|
||||||
await saver.setup()
|
await saver.setup()
|
||||||
yield saver
|
yield saver
|
||||||
return
|
return
|
||||||
@@ -117,15 +150,13 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
|
|||||||
return
|
return
|
||||||
|
|
||||||
if db_config.backend == "postgres":
|
if db_config.backend == "postgres":
|
||||||
try:
|
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(POSTGRES_INSTALL) from exc
|
|
||||||
|
|
||||||
if not db_config.postgres_url:
|
if not db_config.postgres_url:
|
||||||
raise ValueError("database.postgres_url is required for the postgres backend")
|
raise ValueError("database.postgres_url is required for the postgres backend")
|
||||||
|
|
||||||
async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
|
AsyncPostgresSaver, _ = _ensure_postgres_imports()
|
||||||
|
pool = _build_postgres_pool(db_config.postgres_url)
|
||||||
|
async with pool:
|
||||||
|
saver = AsyncPostgresSaver(conn=pool)
|
||||||
await saver.setup()
|
await saver.setup()
|
||||||
yield saver
|
yield saver
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -326,6 +326,99 @@ class TestAsyncCheckpointer:
|
|||||||
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db")
|
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db")
|
||||||
mock_saver.setup.assert_awaited_once()
|
mock_saver.setup.assert_awaited_once()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_postgres_uses_connection_pool(self):
    """Async postgres checkpointer should use AsyncConnectionPool, not a single connection."""
    from deerflow.runtime.checkpointer.async_provider import make_checkpointer

    mock_config = MagicMock()
    mock_config.checkpointer = CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")

    mock_saver = AsyncMock()
    mock_saver_cls = MagicMock(return_value=mock_saver)

    # The provider enters the pool with ``async with pool:``, so the pool
    # *instance* has to behave as an async context manager.
    mock_pool_instance = AsyncMock()
    mock_pool_instance.__aenter__.return_value = mock_pool_instance
    mock_pool_instance.__aexit__.return_value = False

    mock_pool_cls = MagicMock(return_value=mock_pool_instance)
    mock_pool_cls.check_connection = AsyncMock()
    mock_dict_row = MagicMock()

    mock_pg_module = MagicMock()
    mock_pg_module.AsyncPostgresSaver = mock_saver_cls

    mock_psycopg_rows = MagicMock()
    mock_psycopg_rows.dict_row = mock_dict_row

    # Stand-ins for the optional postgres packages, so the test does not need
    # them installed.
    fake_modules = {
        "langgraph.checkpoint.postgres.aio": mock_pg_module,
        "psycopg.rows": mock_psycopg_rows,
        "psycopg_pool": MagicMock(AsyncConnectionPool=mock_pool_cls),
    }

    with (
        patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
        patch.dict(sys.modules, fake_modules),
    ):
        async with make_checkpointer() as saver:
            assert saver is mock_saver

        # Verify the pool was constructed once, with the connection string
        # and the check_connection callback.
        mock_pool_cls.assert_called_once()
        pool_call = mock_pool_cls.call_args
        assert pool_call.args[0] == "postgresql://localhost/db"
        assert pool_call.kwargs["check"] is mock_pool_cls.check_connection

        # Verify the per-connection settings passed through the pool.
        conn_kwargs = pool_call.kwargs["kwargs"]
        assert conn_kwargs["autocommit"] is True
        assert conn_kwargs["prepare_threshold"] == 0
        assert conn_kwargs["row_factory"] is mock_dict_row

        # Verify saver was constructed with the pool (not via from_conn_string)
        mock_saver_cls.assert_called_once_with(conn=mock_pool_instance)
        mock_saver_cls.from_conn_string.assert_not_called()
        mock_saver.setup.assert_awaited_once()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_database_postgres_uses_connection_pool(self):
    """Unified database postgres path should use AsyncConnectionPool with keepalive."""
    from deerflow.config.database_config import DatabaseConfig
    from deerflow.runtime.checkpointer.async_provider import make_checkpointer

    db_config = DatabaseConfig(backend="postgres", postgres_url="postgresql://localhost/db")
    mock_config = MagicMock()
    # No legacy checkpointer section: the unified ``database`` config is used.
    mock_config.checkpointer = None
    mock_config.database = db_config

    mock_saver = AsyncMock()
    mock_saver_cls = MagicMock(return_value=mock_saver)

    # The provider enters the pool with ``async with pool:``, so the pool
    # *instance* has to behave as an async context manager.
    mock_pool_instance = AsyncMock()
    mock_pool_instance.__aenter__.return_value = mock_pool_instance
    mock_pool_instance.__aexit__.return_value = False

    mock_pool_cls = MagicMock(return_value=mock_pool_instance)
    mock_pool_cls.check_connection = AsyncMock()
    mock_dict_row = MagicMock()

    mock_pg_module = MagicMock()
    mock_pg_module.AsyncPostgresSaver = mock_saver_cls

    mock_psycopg_rows = MagicMock()
    mock_psycopg_rows.dict_row = mock_dict_row

    # Stand-ins for the optional postgres packages, so the test does not need
    # them installed.
    fake_modules = {
        "langgraph.checkpoint.postgres.aio": mock_pg_module,
        "psycopg.rows": mock_psycopg_rows,
        "psycopg_pool": MagicMock(AsyncConnectionPool=mock_pool_cls),
    }

    with (
        patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
        patch.dict(sys.modules, fake_modules),
    ):
        async with make_checkpointer() as saver:
            assert saver is mock_saver

        mock_pool_cls.assert_called_once()
        pool_call = mock_pool_cls.call_args
        assert pool_call.args[0] == "postgresql://localhost/db"
        assert pool_call.kwargs["check"] is mock_pool_cls.check_connection

        # The docstring promises keepalive: make sure it is actually requested
        # on the connections the pool creates.
        conn_kwargs = pool_call.kwargs["kwargs"]
        assert conn_kwargs["keepalives"] == 1
        assert {"keepalives_idle", "keepalives_interval", "keepalives_count"} <= conn_kwargs.keys()

        mock_saver_cls.assert_called_once_with(conn=mock_pool_instance)
        mock_saver.setup.assert_awaited_once()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_database_sqlite_creates_parent_dir_via_to_thread(self):
|
async def test_database_sqlite_creates_parent_dir_via_to_thread(self):
|
||||||
"""Unified database SQLite setup should also move path IO off the event loop."""
|
"""Unified database SQLite setup should also move path IO off the event loop."""
|
||||||
|
|||||||
Reference in New Issue
Block a user