mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-26 18:06:00 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c1af6cc4fc | |||
| 761a535d6b |
@@ -67,10 +67,22 @@ async def _async_checkpointer(config) -> AsyncIterator[Checkpointer]:
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
try:
|
||||
from psycopg.rows import dict_row
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not config.connection_string:
|
||||
raise ValueError(POSTGRES_CONN_REQUIRED)
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(config.connection_string) as saver:
|
||||
pool = AsyncConnectionPool(
|
||||
config.connection_string,
|
||||
kwargs={"autocommit": True, "prepare_threshold": 0, "row_factory": dict_row},
|
||||
check=AsyncConnectionPool.check_connection,
|
||||
)
|
||||
async with pool:
|
||||
saver = AsyncPostgresSaver(conn=pool)
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
@@ -111,10 +123,22 @@ async def _async_checkpointer_from_database(db_config) -> AsyncIterator[Checkpoi
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
try:
|
||||
from psycopg.rows import dict_row
|
||||
from psycopg_pool import AsyncConnectionPool
|
||||
except ImportError as exc:
|
||||
raise ImportError(POSTGRES_INSTALL) from exc
|
||||
|
||||
if not db_config.postgres_url:
|
||||
raise ValueError("database.postgres_url is required for the postgres backend")
|
||||
|
||||
async with AsyncPostgresSaver.from_conn_string(db_config.postgres_url) as saver:
|
||||
pool = AsyncConnectionPool(
|
||||
db_config.postgres_url,
|
||||
kwargs={"autocommit": True, "prepare_threshold": 0, "row_factory": dict_row},
|
||||
check=AsyncConnectionPool.check_connection,
|
||||
)
|
||||
async with pool:
|
||||
saver = AsyncPostgresSaver(conn=pool)
|
||||
await saver.setup()
|
||||
yield saver
|
||||
return
|
||||
|
||||
@@ -326,6 +326,53 @@ class TestAsyncCheckpointer:
|
||||
mock_saver_cls.from_conn_string.assert_called_once_with("/tmp/resolved/test.db")
|
||||
mock_saver.setup.assert_awaited_once()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_postgres_uses_connection_pool(self):
|
||||
"""Async postgres checkpointer should use AsyncConnectionPool, not a single connection."""
|
||||
from deerflow.runtime.checkpointer.async_provider import make_checkpointer
|
||||
|
||||
mock_config = MagicMock()
|
||||
mock_config.checkpointer = CheckpointerConfig(type="postgres", connection_string="postgresql://localhost/db")
|
||||
|
||||
mock_saver = AsyncMock()
|
||||
|
||||
mock_saver_cls = MagicMock(return_value=mock_saver)
|
||||
|
||||
mock_pool_instance = AsyncMock()
|
||||
mock_pool_instance.__aenter__.return_value = mock_pool_instance
|
||||
mock_pool_instance.__aexit__.return_value = False
|
||||
|
||||
mock_pool_cls = MagicMock(return_value=mock_pool_instance)
|
||||
mock_pool_cls.check_connection = AsyncMock()
|
||||
mock_dict_row = MagicMock()
|
||||
|
||||
mock_pg_module = MagicMock()
|
||||
mock_pg_module.AsyncPostgresSaver = mock_saver_cls
|
||||
|
||||
mock_psycopg_rows = MagicMock()
|
||||
mock_psycopg_rows.dict_row = mock_dict_row
|
||||
|
||||
with (
|
||||
patch("deerflow.runtime.checkpointer.async_provider.get_app_config", return_value=mock_config),
|
||||
patch.dict(sys.modules, {"langgraph.checkpoint.postgres.aio": mock_pg_module}),
|
||||
patch.dict(sys.modules, {"psycopg.rows": mock_psycopg_rows}),
|
||||
patch.dict(sys.modules, {"psycopg_pool": MagicMock(AsyncConnectionPool=mock_pool_cls)}),
|
||||
):
|
||||
# AsyncConnectionPool() is a callable that returns mock_pool_instance
|
||||
# We need the constructor to be an async context manager
|
||||
async with make_checkpointer() as saver:
|
||||
assert saver is mock_saver
|
||||
|
||||
# Verify the pool was constructed with check Connection
|
||||
mock_pool_cls.assert_called_once()
|
||||
call_kwargs = mock_pool_cls.call_args
|
||||
assert call_kwargs[0][0] == "postgresql://localhost/db"
|
||||
assert call_kwargs[1]["check"] is mock_pool_cls.check_connection
|
||||
|
||||
# Verify saver was constructed with the pool (not via from_conn_string)
|
||||
mock_saver_cls.assert_called_once_with(conn=mock_pool_instance)
|
||||
mock_saver.setup.assert_awaited_once()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# app_config.py integration
|
||||
|
||||
Reference in New Issue
Block a user