77 lines
2.1 KiB
Python
77 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
from sqlalchemy import URL
|
|
from sqlalchemy.engine import make_url
|
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from store.persistence import MappedBase
|
|
from store.persistence.shared import close_in_order
|
|
from store.persistence.types import AppPersistence
|
|
|
|
|
|
def _validate_mysql_driver(db_url: URL) -> str:
    """Ensure *db_url* uses an async MySQL driver and return that driver's name.

    ``make_url`` accepts either a SQLAlchemy ``URL`` or a URL string, so both
    forms are accepted here.

    Raises:
        ValueError: if the driver is anything other than ``aiomysql`` or
            ``asyncmy``.
    """
    driver = make_url(db_url).get_driver_name()
    if driver in ("aiomysql", "asyncmy"):
        return driver
    raise ValueError(f"MySQL persistence requires async SQLAlchemy driver (aiomysql/asyncmy), got: {driver!r}")
|
|
|
|
|
|
def _checkpoint_conn_string(db_url: URL) -> str:
|
|
return db_url.render_as_string(hide_password=False)
|
|
|
|
|
|
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
    """Build the MySQL-backed persistence bundle (ORM engine + LangGraph checkpointer).

    Args:
        db_url: SQLAlchemy ``URL`` (a URL string is also accepted) whose driver
            must be ``aiomysql`` or ``asyncmy``, e.g.
            ``mysql+aiomysql://user:pass@host/db``.
        echo: Forwarded to ``create_async_engine`` to log emitted SQL.
        pool_size: Connection pool size for the SQLAlchemy engine.

    Returns:
        An ``AppPersistence`` with:

        - ``setup()``: creates the LangGraph checkpoint tables and then the
          ORM tables.
        - ``aclose()``: disposes the engine and then closes the checkpointer.

    Raises:
        ValueError: if the URL's driver is not ``aiomysql``/``asyncmy``.
        Exception: whatever opening the checkpointer raises. The engine is
            disposed before the error propagates.
    """
    # Normalise once so a str URL also works below (render_as_string needs a
    # URL object); make_url returns an existing URL unchanged.
    url = make_url(db_url)
    _validate_mysql_driver(url)

    # Imported inside the function so the MySQL checkpoint package is only
    # needed when this builder actually runs.
    from langgraph.checkpoint.mysql.aio import AIOMySQLSaver

    # Side-effect import: presumably registers the ORM model classes on
    # MappedBase.metadata so create_all() in setup() sees them — TODO confirm.
    import store.repositories.models  # noqa: F401

    engine = create_async_engine(
        url,
        echo=echo,
        future=True,
        pool_pre_ping=True,
        pool_size=pool_size,
        # Emit non-ASCII characters as-is in JSON values rather than \u escapes.
        json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
    )

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    # NOTE(review): the driver check also admits ``asyncmy``, but the
    # checkpointer is always AIOMySQLSaver. Confirm its client library is
    # installed in asyncmy deployments.
    try:
        saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(url))
        checkpointer = await saver_cm.__aenter__()
    except BaseException:
        # No AppPersistence (and so no aclose) exists yet. Release the engine
        # here instead of leaking it, then propagate the original error.
        await engine.dispose()
        raise

    async def setup() -> None:
        """Create the checkpoint tables first, then any missing ORM tables."""
        # 1. LangGraph checkpoint tables / migrations
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        # Mirror the manual __aenter__ above: exit the saver's context manager
        # with no exception information.
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        """Dispose the engine, then close the checkpointer."""
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
|