mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
feat(storage): add storage package base
This commit is contained in:
@@ -0,0 +1,9 @@
|
||||
from .mysql import build_mysql_persistence
|
||||
from .postgres import build_postgres_persistence
|
||||
from .sqlite import build_sqlite_persistence
|
||||
|
||||
__all__ = [
|
||||
"build_postgres_persistence",
|
||||
"build_mysql_persistence",
|
||||
"build_sqlite_persistence",
|
||||
]
|
||||
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def _validate_mysql_driver(db_url: str) -> str:
|
||||
url = make_url(db_url)
|
||||
driver = url.get_driver_name()
|
||||
|
||||
if driver not in {"aiomysql", "asyncmy"}:
|
||||
raise ValueError(
|
||||
"MySQL persistence requires async SQLAlchemy driver "
|
||||
f"(aiomysql/asyncmy), got: {driver!r}"
|
||||
)
|
||||
return driver
|
||||
|
||||
|
||||
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||
_validate_mysql_driver(db_url)
|
||||
|
||||
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=echo,
|
||||
future=True,
|
||||
pool_pre_ping=True,
|
||||
pool_size=pool_size,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AIOMySQLSaver.from_conn_string(db_url)
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
# 1. LangGraph checkpoint tables / migrations
|
||||
await checkpointer.setup()
|
||||
|
||||
# 2. ORM business tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
|
||||
async def _close_saver() -> None:
|
||||
await saver_cm.__aexit__(None, None, None)
|
||||
|
||||
async def aclose() -> None:
|
||||
await close_in_order(
|
||||
engine.dispose,
|
||||
_close_saver,
|
||||
)
|
||||
|
||||
return AppPersistence(
|
||||
checkpointer=checkpointer,
|
||||
engine=engine,
|
||||
session_factory=session_factory,
|
||||
setup=setup,
|
||||
aclose=aclose,
|
||||
)
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=echo,
|
||||
future=True,
|
||||
pool_pre_ping=True,
|
||||
pool_size=pool_size,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AsyncPostgresSaver.from_conn_string(db_url)
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
# 1. LangGraph checkpoint tables / migrations
|
||||
await checkpointer.setup()
|
||||
|
||||
# 2. ORM business tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
|
||||
async def _close_saver() -> None:
|
||||
await saver_cm.__aexit__(None, None, None)
|
||||
|
||||
async def aclose() -> None:
|
||||
await close_in_order(
|
||||
engine.dispose,
|
||||
_close_saver,
|
||||
)
|
||||
|
||||
return AppPersistence(
|
||||
checkpointer=checkpointer,
|
||||
engine=engine,
|
||||
session_factory=session_factory,
|
||||
setup=setup,
|
||||
aclose=aclose,
|
||||
)
|
||||
@@ -0,0 +1,69 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
async def build_sqlite_persistence(db_url: URL, *, echo: bool = False) -> AppPersistence:
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=echo,
|
||||
future=True,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def _enable_sqlite_pragmas(dbapi_conn, _record): # noqa: ANN001
|
||||
cursor = dbapi_conn.cursor()
|
||||
try:
|
||||
cursor.execute("PRAGMA journal_mode=WAL;")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL;")
|
||||
cursor.execute("PRAGMA foreign_keys=ON;")
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AsyncSqliteSaver.from_conn_string(db_url.database)
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
# 1. LangGraph checkpoint tables
|
||||
await checkpointer.setup()
|
||||
|
||||
# 2. ORM business tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
|
||||
async def _close_saver() -> None:
|
||||
await saver_cm.__aexit__(None, None, None)
|
||||
|
||||
async def aclose() -> None:
|
||||
await close_in_order(
|
||||
engine.dispose,
|
||||
_close_saver,
|
||||
)
|
||||
|
||||
return AppPersistence(
|
||||
checkpointer=checkpointer,
|
||||
engine=engine,
|
||||
session_factory=session_factory,
|
||||
setup=setup,
|
||||
aclose=aclose,
|
||||
)
|
||||
Reference in New Issue
Block a user