mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
feat(storage): add storage package base
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
from store.persistence.base_model import (
|
||||
Base,
|
||||
DataClassBase,
|
||||
DateTimeMixin,
|
||||
MappedBase,
|
||||
TimeZone,
|
||||
UniversalText,
|
||||
id_key,
|
||||
)
|
||||
|
||||
from .factory import (
|
||||
create_persistence,
|
||||
create_persistence_from_database_config,
|
||||
create_persistence_from_storage_config,
|
||||
storage_config_from_database_config,
|
||||
)
|
||||
from .types import AppPersistence
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"DataClassBase",
|
||||
"DateTimeMixin",
|
||||
"MappedBase",
|
||||
"TimeZone",
|
||||
"UniversalText",
|
||||
"id_key",
|
||||
"create_persistence",
|
||||
"create_persistence_from_database_config",
|
||||
"create_persistence_from_storage_config",
|
||||
"storage_config_from_database_config",
|
||||
"AppPersistence",
|
||||
]
|
||||
@@ -0,0 +1,107 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from sqlalchemy import BigInteger, DateTime, Integer, Text, TypeDecorator
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
|
||||
|
||||
from store.common import DataBaseType
|
||||
from store.config.app_config import get_app_config
|
||||
from store.utils import get_timezone
|
||||
|
||||
timezone = get_timezone()
|
||||
app_config = get_app_config()
|
||||
|
||||
# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT)
|
||||
_id_type = Integer if app_config.storage.driver == DataBaseType.sqlite else BigInteger
|
||||
|
||||
id_key = Annotated[
|
||||
int,
|
||||
mapped_column(
|
||||
_id_type,
|
||||
primary_key=True,
|
||||
unique=True,
|
||||
index=True,
|
||||
autoincrement=True,
|
||||
sort_order=-999,
|
||||
comment="Primary key ID",
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
class UniversalText(TypeDecorator[str]):
|
||||
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
|
||||
|
||||
impl = LONGTEXT if DataBaseType.mysql == app_config.storage.driver else Text
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
|
||||
return value
|
||||
|
||||
def process_result_value(self, value: str | None, dialect) -> str | None: # noqa: ANN001
|
||||
return value
|
||||
|
||||
|
||||
class TimeZone(TypeDecorator[datetime]):
|
||||
"""Timezone-aware datetime type compatible with PostgreSQL and MySQL."""
|
||||
|
||||
impl = DateTime(timezone=True)
|
||||
cache_ok = True
|
||||
|
||||
@property
|
||||
def python_type(self) -> type[datetime]:
|
||||
return datetime
|
||||
|
||||
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
||||
if value is not None and value.utcoffset() != timezone.now().utcoffset():
|
||||
value = timezone.from_datetime(value)
|
||||
return value
|
||||
|
||||
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
||||
if value is not None and value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.tz_info)
|
||||
return value
|
||||
|
||||
|
||||
class DateTimeMixin(MappedAsDataclass):
|
||||
"""Mixin that adds created_time / updated_time columns."""
|
||||
|
||||
created_time: Mapped[datetime] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=timezone.now,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
updated_time: Mapped[datetime | None] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
onupdate=timezone.now,
|
||||
sort_order=999,
|
||||
comment="Updated at",
|
||||
)
|
||||
|
||||
|
||||
class MappedBase(AsyncAttrs, DeclarativeBase):
|
||||
"""Async-capable declarative base for all ORM models."""
|
||||
|
||||
@declared_attr.directive
|
||||
def __tablename__(self) -> str:
|
||||
return self.__name__.lower()
|
||||
|
||||
@declared_attr.directive
|
||||
def __table_args__(self) -> dict:
|
||||
return {"comment": self.__doc__ or ""}
|
||||
|
||||
|
||||
class DataClassBase(MappedAsDataclass, MappedBase):
|
||||
"""Declarative base with native dataclass integration."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class Base(DataClassBase, DateTimeMixin):
|
||||
"""Declarative dataclass base with created_time / updated_time columns."""
|
||||
|
||||
__abstract__ = True
|
||||
@@ -0,0 +1,9 @@
|
||||
from .mysql import build_mysql_persistence
|
||||
from .postgres import build_postgres_persistence
|
||||
from .sqlite import build_sqlite_persistence
|
||||
|
||||
__all__ = [
|
||||
"build_postgres_persistence",
|
||||
"build_mysql_persistence",
|
||||
"build_sqlite_persistence",
|
||||
]
|
||||
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def _validate_mysql_driver(db_url: str) -> str:
|
||||
url = make_url(db_url)
|
||||
driver = url.get_driver_name()
|
||||
|
||||
if driver not in {"aiomysql", "asyncmy"}:
|
||||
raise ValueError(
|
||||
"MySQL persistence requires async SQLAlchemy driver "
|
||||
f"(aiomysql/asyncmy), got: {driver!r}"
|
||||
)
|
||||
return driver
|
||||
|
||||
|
||||
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||
_validate_mysql_driver(db_url)
|
||||
|
||||
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=echo,
|
||||
future=True,
|
||||
pool_pre_ping=True,
|
||||
pool_size=pool_size,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AIOMySQLSaver.from_conn_string(db_url)
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
# 1. LangGraph checkpoint tables / migrations
|
||||
await checkpointer.setup()
|
||||
|
||||
# 2. ORM business tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
|
||||
async def _close_saver() -> None:
|
||||
await saver_cm.__aexit__(None, None, None)
|
||||
|
||||
async def aclose() -> None:
|
||||
await close_in_order(
|
||||
engine.dispose,
|
||||
_close_saver,
|
||||
)
|
||||
|
||||
return AppPersistence(
|
||||
checkpointer=checkpointer,
|
||||
engine=engine,
|
||||
session_factory=session_factory,
|
||||
setup=setup,
|
||||
aclose=aclose,
|
||||
)
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=echo,
|
||||
future=True,
|
||||
pool_pre_ping=True,
|
||||
pool_size=pool_size,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AsyncPostgresSaver.from_conn_string(db_url)
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
# 1. LangGraph checkpoint tables / migrations
|
||||
await checkpointer.setup()
|
||||
|
||||
# 2. ORM business tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
|
||||
async def _close_saver() -> None:
|
||||
await saver_cm.__aexit__(None, None, None)
|
||||
|
||||
async def aclose() -> None:
|
||||
await close_in_order(
|
||||
engine.dispose,
|
||||
_close_saver,
|
||||
)
|
||||
|
||||
return AppPersistence(
|
||||
checkpointer=checkpointer,
|
||||
engine=engine,
|
||||
session_factory=session_factory,
|
||||
setup=setup,
|
||||
aclose=aclose,
|
||||
)
|
||||
@@ -0,0 +1,69 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
async def build_sqlite_persistence(db_url: URL, *, echo: bool = False) -> AppPersistence:
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=echo,
|
||||
future=True,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def _enable_sqlite_pragmas(dbapi_conn, _record): # noqa: ANN001
|
||||
cursor = dbapi_conn.cursor()
|
||||
try:
|
||||
cursor.execute("PRAGMA journal_mode=WAL;")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL;")
|
||||
cursor.execute("PRAGMA foreign_keys=ON;")
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AsyncSqliteSaver.from_conn_string(db_url.database)
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
# 1. LangGraph checkpoint tables
|
||||
await checkpointer.setup()
|
||||
|
||||
# 2. ORM business tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
|
||||
async def _close_saver() -> None:
|
||||
await saver_cm.__aexit__(None, None, None)
|
||||
|
||||
async def aclose() -> None:
|
||||
await close_in_order(
|
||||
engine.dispose,
|
||||
_close_saver,
|
||||
)
|
||||
|
||||
return AppPersistence(
|
||||
checkpointer=checkpointer,
|
||||
engine=engine,
|
||||
session_factory=session_factory,
|
||||
setup=setup,
|
||||
aclose=aclose,
|
||||
)
|
||||
@@ -0,0 +1,121 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from store.common import DataBaseType
|
||||
from store.config.app_config import get_app_config
|
||||
from store.config.storage_config import StorageConfig
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def storage_config_from_database_config(database_config: Any) -> StorageConfig:
|
||||
"""Convert the existing public DatabaseConfig shape to StorageConfig.
|
||||
|
||||
Storage only owns durable database-backed persistence. The app bridge
|
||||
should handle memory mode before calling into this package.
|
||||
"""
|
||||
backend = getattr(database_config, "backend", None)
|
||||
if backend == "sqlite":
|
||||
return StorageConfig(
|
||||
driver="sqlite",
|
||||
sqlite_dir=getattr(database_config, "sqlite_dir", ".deer-flow/data"),
|
||||
echo_sql=getattr(database_config, "echo_sql", False),
|
||||
pool_size=getattr(database_config, "pool_size", 5),
|
||||
)
|
||||
|
||||
if backend == "postgres":
|
||||
postgres_url = getattr(database_config, "postgres_url", "")
|
||||
if not postgres_url:
|
||||
raise ValueError("database.postgres_url is required when database.backend is 'postgres'")
|
||||
parsed = make_url(postgres_url)
|
||||
return StorageConfig(
|
||||
driver="postgres",
|
||||
database_url=postgres_url,
|
||||
username=parsed.username or "",
|
||||
password=parsed.password or "",
|
||||
host=parsed.host or "localhost",
|
||||
port=parsed.port or 5432,
|
||||
db_name=parsed.database or "deerflow",
|
||||
echo_sql=getattr(database_config, "echo_sql", False),
|
||||
pool_size=getattr(database_config, "pool_size", 5),
|
||||
)
|
||||
|
||||
raise ValueError(f"Unsupported database backend for storage persistence: {backend!r}")
|
||||
|
||||
|
||||
def _create_database_url(storage_config: StorageConfig) -> URL:
|
||||
"""Build an async SQLAlchemy URL from StorageConfig (sqlite/mysql/postgres)."""
|
||||
|
||||
if storage_config.driver == DataBaseType.sqlite:
|
||||
driver = "sqlite+aiosqlite"
|
||||
elif storage_config.driver == DataBaseType.mysql:
|
||||
driver = "mysql+aiomysql"
|
||||
elif storage_config.driver in (DataBaseType.postgresql, "postgres"):
|
||||
driver = "postgresql+asyncpg"
|
||||
else:
|
||||
raise ValueError(f"Unsupported database driver: {storage_config.driver}")
|
||||
|
||||
if storage_config.driver == DataBaseType.sqlite:
|
||||
import os
|
||||
|
||||
db_path = storage_config.sqlite_storage_path
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
url = URL.create(
|
||||
drivername=driver,
|
||||
database=db_path,
|
||||
)
|
||||
elif storage_config.database_url:
|
||||
url = make_url(storage_config.database_url)
|
||||
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
|
||||
url = url.set(drivername="postgresql+asyncpg")
|
||||
else:
|
||||
url = URL.create(
|
||||
drivername=driver,
|
||||
username=storage_config.username,
|
||||
password=storage_config.password,
|
||||
host=storage_config.host,
|
||||
port=storage_config.port,
|
||||
database=storage_config.db_name or "deerflow",
|
||||
)
|
||||
|
||||
return url
|
||||
|
||||
|
||||
async def create_persistence_from_storage_config(storage_config: StorageConfig) -> AppPersistence:
|
||||
from .drivers.mysql import build_mysql_persistence
|
||||
from .drivers.postgres import build_postgres_persistence
|
||||
from .drivers.sqlite import build_sqlite_persistence
|
||||
|
||||
driver = storage_config.driver
|
||||
db_url = _create_database_url(storage_config)
|
||||
|
||||
if driver in ("postgres", "postgresql"):
|
||||
return await build_postgres_persistence(
|
||||
db_url,
|
||||
echo=storage_config.echo_sql,
|
||||
pool_size=storage_config.pool_size,
|
||||
)
|
||||
|
||||
if driver == "mysql":
|
||||
return await build_mysql_persistence(
|
||||
db_url,
|
||||
echo=storage_config.echo_sql,
|
||||
pool_size=storage_config.pool_size,
|
||||
)
|
||||
|
||||
if driver == "sqlite":
|
||||
return await build_sqlite_persistence(db_url, echo=storage_config.echo_sql)
|
||||
|
||||
raise ValueError(f"Unsupported database driver: {driver}")
|
||||
|
||||
|
||||
async def create_persistence_from_database_config(database_config: Any) -> AppPersistence:
|
||||
storage_config = storage_config_from_database_config(database_config)
|
||||
return await create_persistence_from_storage_config(storage_config)
|
||||
|
||||
|
||||
async def create_persistence() -> AppPersistence:
|
||||
app_config = get_app_config()
|
||||
return await create_persistence_from_storage_config(app_config.storage)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .close import close_in_order
|
||||
|
||||
__all__ = ["close_in_order"]
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
AsyncCloser = Callable[[], Awaitable[None]]
|
||||
|
||||
|
||||
async def close_in_order(*closers: AsyncCloser) -> None:
|
||||
"""
|
||||
Run async closers in order and raise the first error, if any.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Used to keep driver-specific close logic readable.
|
||||
- We intentionally do not stop at first failure, so later resources
|
||||
still get a chance to close.
|
||||
"""
|
||||
first_error: Exception | None = None
|
||||
|
||||
for closer in closers:
|
||||
try:
|
||||
await closer()
|
||||
except Exception as exc:
|
||||
if first_error is None:
|
||||
first_error = exc
|
||||
|
||||
if first_error is not None:
|
||||
raise first_error
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
AsyncSetup = Callable[[], Awaitable[None]]
|
||||
AsyncClose = Callable[[], Awaitable[None]]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AppPersistence:
|
||||
"""
|
||||
Unified runtime persistence bundle.
|
||||
"""
|
||||
checkpointer: Checkpointer
|
||||
engine: AsyncEngine
|
||||
session_factory: async_sessionmaker[AsyncSession]
|
||||
setup: AsyncSetup
|
||||
aclose: AsyncClose
|
||||
Reference in New Issue
Block a user