diff --git a/backend/packages/storage/pyproject.toml b/backend/packages/storage/pyproject.toml new file mode 100644 index 000000000..71d04701a --- /dev/null +++ b/backend/packages/storage/pyproject.toml @@ -0,0 +1,35 @@ +[project] +name = "deerflow-storage" +version = "0.1.0" +description = "DeerFlow storage framework" +requires-python = ">=3.12" +dependencies = [ + "dotenv>=0.9.9", + "pydantic>=2.12.5", + "pyyaml>=6.0.3", + "sqlalchemy[asyncio]>=2.0,<3.0", + "alembic>=1.13", + "langgraph>=1.1.9", +] +[project.optional-dependencies] +postgres = [ + "asyncpg>=0.29", + "langgraph-checkpoint-postgres>=3.0.5", + "psycopg[binary]>=3.3.3", + "psycopg-pool>=3.3.0", +] +mysql = [ + "aiomysql>=0.2", + "langgraph-checkpoint-mysql>=3.0.0", +] +sqlite = [ + "aiosqlite>=0.22.1", + "langgraph-checkpoint-sqlite>=3.0.3" +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["store"] diff --git a/backend/packages/storage/store/__init__.py b/backend/packages/storage/store/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/packages/storage/store/common/__init__.py b/backend/packages/storage/store/common/__init__.py new file mode 100644 index 000000000..e21d63c5d --- /dev/null +++ b/backend/packages/storage/store/common/__init__.py @@ -0,0 +1,5 @@ +from .enums import DataBaseType + +__all__ = [ + 'DataBaseType', +] diff --git a/backend/packages/storage/store/common/enums.py b/backend/packages/storage/store/common/enums.py new file mode 100644 index 000000000..1df841835 --- /dev/null +++ b/backend/packages/storage/store/common/enums.py @@ -0,0 +1,41 @@ +from enum import Enum +from enum import IntEnum as SourceIntEnum +from enum import StrEnum as SourceStrEnum +from typing import Any, TypeVar + +T = TypeVar('T', bound=Enum) + + +class _EnumBase: + """Base enum class with common utility methods.""" + + @classmethod + def get_member_keys(cls) -> list[str]: + """Return a list of enum member names.""" + return list(cls.__members__.keys()) + + @classmethod + def get_member_values(cls) -> list: + """Return a list of enum member values.""" + return [item.value for item in cls.__members__.values()] + + @classmethod + def get_member_dict(cls) -> dict[str, Any]: + """Return a dict mapping member names to values.""" + return {name: item.value for name, item in cls.__members__.items()} + + +class IntEnum(_EnumBase, SourceIntEnum): + """Integer enum base class.""" + + +class StrEnum(_EnumBase, SourceStrEnum): + """String enum base class.""" + + +class DataBaseType(StrEnum): + """Database type.""" + + sqlite = 'sqlite' + mysql = 'mysql' + postgresql = 'postgresql' diff --git a/backend/packages/storage/store/config/__init__.py b/backend/packages/storage/store/config/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/packages/storage/store/config/app_config.py b/backend/packages/storage/store/config/app_config.py new file mode 100644 index 000000000..fc61d7b02 --- /dev/null +++ b/backend/packages/storage/store/config/app_config.py @@ -0,0 +1,290 @@ +import logging +import os +from contextvars import ContextVar +from pathlib import Path +from typing import Any, Self + +import yaml +from dotenv import load_dotenv +from pydantic import BaseModel, ConfigDict, Field + +from store.config.storage_config import StorageConfig + +load_dotenv() + +logger = logging.getLogger(__name__) + + +def _default_config_candidates() -> tuple[Path, ...]: + """Return deterministic config.yaml locations without relying on cwd.""" + backend_dir = Path(__file__).resolve().parents[4] + repo_root = backend_dir.parent + cwd = Path.cwd().resolve() + candidates = ( + cwd / "config.yaml", + backend_dir / "config.yaml", + repo_root / "config.yaml", + ) + return tuple(dict.fromkeys(candidates)) + + +def _storage_from_database_config(config_data: dict[str, Any]) -> None: + """Keep the existing public `database:` config compatible with storage.""" + if "storage" in config_data: + return + + database = config_data.get("database") + if not isinstance(database, dict): + return + + backend = database.get("backend") + if backend == "memory": + raise ValueError("database.backend='memory' is not supported by storage; handle memory mode before loading storage config") + + storage: dict[str, Any] = { + "driver": "postgres" if backend == "postgres" else backend, + "sqlite_dir": database.get("sqlite_dir", ".deer-flow/data"), + "echo_sql": database.get("echo_sql", False), + "pool_size": database.get("pool_size", 5), + } + + postgres_url = database.get("postgres_url") + if backend == "postgres" and isinstance(postgres_url, str) and postgres_url: + from sqlalchemy.engine.url import make_url + + parsed = make_url(postgres_url) + storage["database_url"] = postgres_url + storage.update( + { + "username": parsed.username or "", + "password": parsed.password or "", + "host": parsed.host or "localhost", + "port": parsed.port or 5432, + "db_name": parsed.database or "deerflow", + } + ) + + config_data["storage"] = storage + + +class AppConfig(BaseModel): + """DeerFlow application configuration.""" + + timezone: str = Field(default="UTC", description="Timezone for scheduling and timestamps (e.g. 'UTC', 'America/New_York')") + log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)") + storage: StorageConfig = Field(default=StorageConfig()) + model_config = ConfigDict(extra="allow", frozen=False) + + @classmethod + def resolve_config_path(cls, config_path: str | None = None) -> Path: + """Resolve the config file path. + + Priority: + 1. If provided `config_path` argument, use it. + 2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it. + 3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`. + """ + if config_path: + path = Path(config_path) + if not Path.exists(path): + raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}") + return path + elif os.getenv("DEER_FLOW_CONFIG_PATH"): + path = Path(os.getenv("DEER_FLOW_CONFIG_PATH")) + if not Path.exists(path): + raise FileNotFoundError( + f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}") + return path + else: + for path in _default_config_candidates(): + if path.exists(): + return path + raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations") + + @classmethod + def from_file(cls, config_path: str | None = None) -> Self: + """Load and validate config from YAML. See `resolve_config_path` for path resolution.""" + resolved_path = cls.resolve_config_path(config_path) + with open(resolved_path, encoding="utf-8") as f: + config_data = yaml.safe_load(f) or {} + + cls._check_config_version(config_data, resolved_path) + + config_data = cls.resolve_env_variables(config_data) + _storage_from_database_config(config_data) + + if os.getenv("TIMEZONE"): + config_data["timezone"] = os.getenv("TIMEZONE") + + result = cls.model_validate(config_data) + return result + + @classmethod + def _check_config_version(cls, config_data: dict, config_path: Path) -> None: + """Check if the user's config.yaml is outdated compared to config.example.yaml. + + Emits a warning if the user's config_version is lower than the example's. + Missing config_version is treated as version 0 (pre-versioning). + """ + try: + user_version = int(config_data.get("config_version", 0)) + except (TypeError, ValueError): + user_version = 0 + + # Find config.example.yaml by searching config.yaml's directory and its parents + example_path = None + search_dir = config_path.parent + for _ in range(5): # search up to 5 levels + candidate = search_dir / "config.example.yaml" + if candidate.exists(): + example_path = candidate + break + parent = search_dir.parent + if parent == search_dir: + break + search_dir = parent + if example_path is None: + return + + try: + with open(example_path, encoding="utf-8") as f: + example_data = yaml.safe_load(f) + raw = example_data.get("config_version", 0) if example_data else 0 + try: + example_version = int(raw) + except (TypeError, ValueError): + example_version = 0 + except Exception: + return + + if user_version < example_version: + logger.warning( + "Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to " + "merge new fields into your config.", + user_version, + example_version, + ) + + @classmethod + def resolve_env_variables(cls, config: Any) -> Any: + """Recursively replace $VAR strings with their environment variable values (e.g. $OPENAI_API_KEY).""" + if isinstance(config, str): + if config.startswith("$"): + env_value = os.getenv(config[1:]) + if env_value is None: + raise ValueError(f"Environment variable {config[1:]} not found for config value {config}") + return env_value + return config + elif isinstance(config, dict): + return {k: cls.resolve_env_variables(v) for k, v in config.items()} + elif isinstance(config, list): + return [cls.resolve_env_variables(item) for item in config] + return config + + + +_app_config: AppConfig | None = None +_app_config_path: Path | None = None +_app_config_mtime: float | None = None +_app_config_is_custom = False +_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None) +_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", + default=()) + + +def _get_config_mtime(config_path: Path) -> float | None: + """Get the modification time of a config file if it exists.""" + try: + return config_path.stat().st_mtime + except OSError: + return None + + +def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig: + """Load config from disk and refresh cache metadata.""" + global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom + + resolved_path = AppConfig.resolve_config_path(config_path) + _app_config = AppConfig.from_file(str(resolved_path)) + _app_config_path = resolved_path + _app_config_mtime = _get_config_mtime(resolved_path) + _app_config_is_custom = False + return _app_config + + +def get_app_config() -> AppConfig: + """Get the DeerFlow config instance. + + Returns a cached singleton instance and automatically reloads it when the + underlying config file path or modification time changes. Use + `reload_app_config()` to force a reload, or `reset_app_config()` to clear + the cache. + """ + global _app_config, _app_config_path, _app_config_mtime + + runtime_override = _current_app_config.get() + if runtime_override is not None: + return runtime_override + + if _app_config is not None and _app_config_is_custom: + return _app_config + + resolved_path = AppConfig.resolve_config_path() + current_mtime = _get_config_mtime(resolved_path) + + should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime + if should_reload: + if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime: + logger.info( + "Config file has been modified (mtime: %s -> %s), reloading AppConfig", + _app_config_mtime, + current_mtime, + ) + _load_and_cache_app_config(str(resolved_path)) + return _app_config + + +def reload_app_config(config_path: str | None = None) -> AppConfig: + """Force reload from file and update the cache.""" + return _load_and_cache_app_config(config_path) + + +def reset_app_config() -> None: + """Clear the cache so the next `get_app_config()` reloads from file.""" + global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom + _app_config = None + _app_config_path = None + _app_config_mtime = None + _app_config_is_custom = False + + +def set_app_config(config: AppConfig) -> None: + """Inject a config instance directly, bypassing file loading (for testing).""" + global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom + _app_config = config + _app_config_path = None + _app_config_mtime = None + _app_config_is_custom = True + + +def peek_current_app_config() -> AppConfig | None: + """Return the runtime-scoped AppConfig override, if one is active.""" + return _current_app_config.get() + + +def push_current_app_config(config: AppConfig) -> None: + """Push a runtime-scoped AppConfig override for the current execution context.""" + stack = _current_app_config_stack.get() + _current_app_config_stack.set(stack + (_current_app_config.get(),)) + _current_app_config.set(config) + + +def pop_current_app_config() -> None: + """Pop the latest runtime-scoped AppConfig override for the current execution context.""" + stack = _current_app_config_stack.get() + if not stack: + _current_app_config.set(None) + return + previous = stack[-1] + _current_app_config_stack.set(stack[:-1]) + _current_app_config.set(previous) diff --git a/backend/packages/storage/store/config/storage_config.py b/backend/packages/storage/store/config/storage_config.py new file mode 100644 index 000000000..9f55a19aa --- /dev/null +++ b/backend/packages/storage/store/config/storage_config.py @@ -0,0 +1,72 @@ +"""Unified storage backend configuration for checkpointer and application data. + +SQLite: checkpointer → {sqlite_dir}/checkpoints.db, app → {sqlite_dir}/deerflow.db + (separate files to avoid write-lock contention) +Postgres: shared URL, independent connection pools per layer. + +Sensitive values use $VAR syntax resolved by AppConfig.resolve_env_variables() +before this config is instantiated. +""" + +from __future__ import annotations + +import os +from typing import Literal + +from pydantic import BaseModel, Field + + +def _strip_legacy_state_prefix(path: str) -> str: + """Keep old .deer-flow/* config values compatible with Paths.base_dir.""" + prefix = ".deer-flow/" + if path == ".deer-flow": + return "." + if path.startswith(prefix): + return path[len(prefix):] + return path + + +class StorageConfig(BaseModel): + driver: Literal["mysql", "sqlite", "postgres", "postgresql"] = Field( + default="sqlite", + description="Storage driver for both checkpointer and application data. " + "'sqlite' for single-node deployment (default)," + "'postgres' for production multi-node deployment, " + "'mysql' for MySQL databases.", + ) + sqlite_dir: str = Field( + default=".deer-flow/data", + description="Directory for SQLite .db files (sqlite driver only).", + ) + username: str = Field(default="", description="db username ") + password: str = Field(default="", description="db password. Use $VAR syntax in config.yaml to read from .env.") + host: str = Field(default="localhost", description="db host.") + port: int = Field(default=5432, description="db port.") + db_name: str = Field(default="deerflow", description="db database name.") + database_url: str = Field(default="", description="Complete SQLAlchemy database URL. Takes precedence for non-SQLite drivers.") + sqlite_db_path: str = Field(default=".deer-flow/data", description="Directory for SQLite .db files (sqlite driver only).") + echo_sql: bool = Field(default=False, description="Log all SQL statements (debug only).") + pool_size: int = Field(default=5, description="Connection pool size per layer.") + + # -- Derived helpers (not user-configured) -- + + @property + def _resolved_sqlite_dir(self) -> str: + """Resolve sqlite_dir to an absolute path under DeerFlow's base dir.""" + from pathlib import Path + + path = Path(self.sqlite_dir) + if path.is_absolute(): + return str(path.resolve()) + + try: + from deerflow.config.paths import resolve_path + + return str(resolve_path(_strip_legacy_state_prefix(self.sqlite_dir))) + except ImportError: + return str(path.resolve()) + + @property + def sqlite_storage_path(self) -> str: + """SQLite file path for storage-owned app data and checkpointer.""" + return os.path.join(self._resolved_sqlite_dir, "deerflow.db") diff --git a/backend/packages/storage/store/persistence/__init__.py b/backend/packages/storage/store/persistence/__init__.py new file mode 100644 index 000000000..17caa12d2 --- /dev/null +++ b/backend/packages/storage/store/persistence/__init__.py @@ -0,0 +1,32 @@ +from store.persistence.base_model import ( + Base, + DataClassBase, + DateTimeMixin, + MappedBase, + TimeZone, + UniversalText, + id_key, +) + +from .factory import ( + create_persistence, + create_persistence_from_database_config, + create_persistence_from_storage_config, + storage_config_from_database_config, +) +from .types import AppPersistence + +__all__ = [ + "Base", + "DataClassBase", + "DateTimeMixin", + "MappedBase", + "TimeZone", + "UniversalText", + "id_key", + "create_persistence", + "create_persistence_from_database_config", + "create_persistence_from_storage_config", + "storage_config_from_database_config", + "AppPersistence", +] diff --git a/backend/packages/storage/store/persistence/base_model.py b/backend/packages/storage/store/persistence/base_model.py new file mode 100644 index 000000000..e60562020 --- /dev/null +++ b/backend/packages/storage/store/persistence/base_model.py @@ -0,0 +1,107 @@ +from datetime import datetime +from typing import Annotated + +from sqlalchemy import BigInteger, DateTime, Integer, Text, TypeDecorator +from sqlalchemy.dialects.mysql import LONGTEXT +from sqlalchemy.ext.asyncio import AsyncAttrs +from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column + +from store.common import DataBaseType +from store.config.app_config import get_app_config +from store.utils import get_timezone + +timezone = get_timezone() +app_config = get_app_config() + +# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT) +_id_type = Integer if app_config.storage.driver == DataBaseType.sqlite else BigInteger + +id_key = Annotated[ + int, + mapped_column( + _id_type, + primary_key=True, + unique=True, + index=True, + autoincrement=True, + sort_order=-999, + comment="Primary key ID", + ) +] + + +class UniversalText(TypeDecorator[str]): + """Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL).""" + + impl = LONGTEXT if DataBaseType.mysql == app_config.storage.driver else Text + cache_ok = True + + def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001 + return value + + def process_result_value(self, value: str | None, dialect) -> str | None: # noqa: ANN001 + return value + + +class TimeZone(TypeDecorator[datetime]): + """Timezone-aware datetime type compatible with PostgreSQL and MySQL.""" + + impl = DateTime(timezone=True) + cache_ok = True + + @property + def python_type(self) -> type[datetime]: + return datetime + + def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001 + if value is not None and value.utcoffset() != timezone.now().utcoffset(): + value = timezone.from_datetime(value) + return value + + def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001 + if value is not None and value.tzinfo is None: + value = value.replace(tzinfo=timezone.tz_info) + return value + + +class DateTimeMixin(MappedAsDataclass): + """Mixin that adds created_time / updated_time columns.""" + + created_time: Mapped[datetime] = mapped_column( + TimeZone, + init=False, + default_factory=timezone.now, + sort_order=999, + comment="Created at", + ) + updated_time: Mapped[datetime | None] = mapped_column( + TimeZone, + init=False, + onupdate=timezone.now, + sort_order=999, + comment="Updated at", + ) + + +class MappedBase(AsyncAttrs, DeclarativeBase): + """Async-capable declarative base for all ORM models.""" + + @declared_attr.directive + def __tablename__(self) -> str: + return self.__name__.lower() + + @declared_attr.directive + def __table_args__(self) -> dict: + return {"comment": self.__doc__ or ""} + + +class DataClassBase(MappedAsDataclass, MappedBase): + """Declarative base with native dataclass integration.""" + + __abstract__ = True + + +class Base(DataClassBase, DateTimeMixin): + """Declarative dataclass base with created_time / updated_time columns.""" + + __abstract__ = True diff --git a/backend/packages/storage/store/persistence/drivers/__init__.py b/backend/packages/storage/store/persistence/drivers/__init__.py new file mode 100644 index 000000000..7f9e45d04 --- /dev/null +++ b/backend/packages/storage/store/persistence/drivers/__init__.py @@ -0,0 +1,9 @@ +from .mysql import build_mysql_persistence +from .postgres import build_postgres_persistence +from .sqlite import build_sqlite_persistence + +__all__ = [ + "build_postgres_persistence", + "build_mysql_persistence", + "build_sqlite_persistence", +] diff --git a/backend/packages/storage/store/persistence/drivers/mysql.py b/backend/packages/storage/store/persistence/drivers/mysql.py new file mode 100644 index 000000000..f68b117b7 --- /dev/null +++ b/backend/packages/storage/store/persistence/drivers/mysql.py @@ -0,0 +1,75 @@ +from __future__ import annotations + +import json + +from sqlalchemy import URL +from sqlalchemy.engine import make_url +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from store.persistence import MappedBase +from store.persistence.shared import close_in_order +from store.persistence.types import AppPersistence + + +def _validate_mysql_driver(db_url: str) -> str: + url = make_url(db_url) + driver = url.get_driver_name() + + if driver not in {"aiomysql", "asyncmy"}: + raise ValueError( + "MySQL persistence requires async SQLAlchemy driver " + f"(aiomysql/asyncmy), got: {driver!r}" + ) + return driver + + +async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence: + _validate_mysql_driver(db_url) + + from langgraph.checkpoint.mysql.aio import AIOMySQLSaver + + import store.repositories.models # noqa: F401 + + engine = create_async_engine( + db_url, + echo=echo, + future=True, + pool_pre_ping=True, + pool_size=pool_size, + json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), + ) + + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + + saver_cm = AIOMySQLSaver.from_conn_string(db_url) + checkpointer = await saver_cm.__aenter__() + + async def setup() -> None: + # 1. LangGraph checkpoint tables / migrations + await checkpointer.setup() + + # 2. ORM business tables + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + + async def _close_saver() -> None: + await saver_cm.__aexit__(None, None, None) + + async def aclose() -> None: + await close_in_order( + engine.dispose, + _close_saver, + ) + + return AppPersistence( + checkpointer=checkpointer, + engine=engine, + session_factory=session_factory, + setup=setup, + aclose=aclose, + ) diff --git a/backend/packages/storage/store/persistence/drivers/postgres.py b/backend/packages/storage/store/persistence/drivers/postgres.py new file mode 100644 index 000000000..99b98a4ff --- /dev/null +++ b/backend/packages/storage/store/persistence/drivers/postgres.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import json + +from sqlalchemy import URL +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from store.persistence import MappedBase +from store.persistence.shared import close_in_order +from store.persistence.types import AppPersistence + + +async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence: + from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver + + import store.repositories.models # noqa: F401 + + engine = create_async_engine( + db_url, + echo=echo, + future=True, + pool_pre_ping=True, + pool_size=pool_size, + json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), + ) + + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + + saver_cm = AsyncPostgresSaver.from_conn_string(db_url) + checkpointer = await saver_cm.__aenter__() + + async def setup() -> None: + # 1. LangGraph checkpoint tables / migrations + await checkpointer.setup() + + # 2. ORM business tables + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + + async def _close_saver() -> None: + await saver_cm.__aexit__(None, None, None) + + async def aclose() -> None: + await close_in_order( + engine.dispose, + _close_saver, + ) + + return AppPersistence( + checkpointer=checkpointer, + engine=engine, + session_factory=session_factory, + setup=setup, + aclose=aclose, + ) diff --git a/backend/packages/storage/store/persistence/drivers/sqlite.py b/backend/packages/storage/store/persistence/drivers/sqlite.py new file mode 100644 index 000000000..be669baa5 --- /dev/null +++ b/backend/packages/storage/store/persistence/drivers/sqlite.py @@ -0,0 +1,69 @@ + +from __future__ import annotations + +import json + +from sqlalchemy import URL, event +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from store.persistence import MappedBase +from store.persistence.shared import close_in_order +from store.persistence.types import AppPersistence + + +async def build_sqlite_persistence(db_url: URL, *, echo: bool = False) -> AppPersistence: + from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver + + import store.repositories.models # noqa: F401 + + engine = create_async_engine( + db_url, + echo=echo, + future=True, + json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False), + ) + + @event.listens_for(engine.sync_engine, "connect") + def _enable_sqlite_pragmas(dbapi_conn, _record): # noqa: ANN001 + cursor = dbapi_conn.cursor() + try: + cursor.execute("PRAGMA journal_mode=WAL;") + cursor.execute("PRAGMA synchronous=NORMAL;") + cursor.execute("PRAGMA foreign_keys=ON;") + finally: + cursor.close() + + session_factory = async_sessionmaker( + bind=engine, + class_=AsyncSession, + expire_on_commit=False, + autoflush=False, + ) + + saver_cm = AsyncSqliteSaver.from_conn_string(db_url.database) + checkpointer = await saver_cm.__aenter__() + + async def setup() -> None: + # 1. LangGraph checkpoint tables + await checkpointer.setup() + + # 2. ORM business tables + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + + async def _close_saver() -> None: + await saver_cm.__aexit__(None, None, None) + + async def aclose() -> None: + await close_in_order( + engine.dispose, + _close_saver, + ) + + return AppPersistence( + checkpointer=checkpointer, + engine=engine, + session_factory=session_factory, + setup=setup, + aclose=aclose, + ) diff --git a/backend/packages/storage/store/persistence/factory.py b/backend/packages/storage/store/persistence/factory.py new file mode 100644 index 000000000..30de5230e --- /dev/null +++ b/backend/packages/storage/store/persistence/factory.py @@ -0,0 +1,121 @@ +from typing import Any + +from sqlalchemy import URL +from sqlalchemy.engine.url import make_url + +from store.common import DataBaseType +from store.config.app_config import get_app_config +from store.config.storage_config import StorageConfig +from store.persistence.types import AppPersistence + + +def storage_config_from_database_config(database_config: Any) -> StorageConfig: + """Convert the existing public DatabaseConfig shape to StorageConfig. + + Storage only owns durable database-backed persistence. The app bridge + should handle memory mode before calling into this package. + """ + backend = getattr(database_config, "backend", None) + if backend == "sqlite": + return StorageConfig( + driver="sqlite", + sqlite_dir=getattr(database_config, "sqlite_dir", ".deer-flow/data"), + echo_sql=getattr(database_config, "echo_sql", False), + pool_size=getattr(database_config, "pool_size", 5), + ) + + if backend == "postgres": + postgres_url = getattr(database_config, "postgres_url", "") + if not postgres_url: + raise ValueError("database.postgres_url is required when database.backend is 'postgres'") + parsed = make_url(postgres_url) + return StorageConfig( + driver="postgres", + database_url=postgres_url, + username=parsed.username or "", + password=parsed.password or "", + host=parsed.host or "localhost", + port=parsed.port or 5432, + db_name=parsed.database or "deerflow", + echo_sql=getattr(database_config, "echo_sql", False), + pool_size=getattr(database_config, "pool_size", 5), + ) + + raise ValueError(f"Unsupported database backend for storage persistence: {backend!r}") + + +def _create_database_url(storage_config: StorageConfig) -> URL: + """Build an async SQLAlchemy URL from StorageConfig (sqlite/mysql/postgres).""" + + if storage_config.driver == DataBaseType.sqlite: + driver = "sqlite+aiosqlite" + elif storage_config.driver == DataBaseType.mysql: + driver = "mysql+aiomysql" + elif storage_config.driver in (DataBaseType.postgresql, "postgres"): + driver = "postgresql+asyncpg" + else: + raise ValueError(f"Unsupported database driver: {storage_config.driver}") + + if storage_config.driver == DataBaseType.sqlite: + import os + + db_path = storage_config.sqlite_storage_path + os.makedirs(os.path.dirname(db_path), exist_ok=True) + + url = URL.create( + drivername=driver, + database=db_path, + ) + elif storage_config.database_url: + url = make_url(storage_config.database_url) + if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql": + url = url.set(drivername="postgresql+asyncpg") + else: + url = URL.create( + drivername=driver, + username=storage_config.username, + password=storage_config.password, + host=storage_config.host, + port=storage_config.port, + database=storage_config.db_name or "deerflow", + ) + + return url + + +async def create_persistence_from_storage_config(storage_config: StorageConfig) -> AppPersistence: + from .drivers.mysql import build_mysql_persistence + from .drivers.postgres import build_postgres_persistence + from .drivers.sqlite import build_sqlite_persistence + + driver = storage_config.driver + db_url = _create_database_url(storage_config) + + if driver in ("postgres", "postgresql"): + return await build_postgres_persistence( + db_url, + echo=storage_config.echo_sql, + pool_size=storage_config.pool_size, + ) + + if driver == "mysql": + return await build_mysql_persistence( + db_url, + echo=storage_config.echo_sql, + pool_size=storage_config.pool_size, + ) + + if driver == "sqlite": + return await build_sqlite_persistence(db_url, echo=storage_config.echo_sql) + + raise ValueError(f"Unsupported database driver: {driver}") + + +async def create_persistence_from_database_config(database_config: Any) -> AppPersistence: + storage_config = storage_config_from_database_config(database_config) + return await create_persistence_from_storage_config(storage_config) + + +async def create_persistence() -> AppPersistence: + app_config = get_app_config() + return await create_persistence_from_storage_config(app_config.storage) diff --git a/backend/packages/storage/store/persistence/shared/__init__.py b/backend/packages/storage/store/persistence/shared/__init__.py new file mode 100644 index 000000000..95e6677ce --- /dev/null +++ b/backend/packages/storage/store/persistence/shared/__init__.py @@ -0,0 +1,3 @@ +from .close import close_in_order + +__all__ = ["close_in_order"] diff --git a/backend/packages/storage/store/persistence/shared/close.py b/backend/packages/storage/store/persistence/shared/close.py new file mode 100644 index 000000000..e912805cc --- /dev/null +++ b/backend/packages/storage/store/persistence/shared/close.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable + +AsyncCloser = Callable[[], Awaitable[None]] + + +async def close_in_order(*closers: AsyncCloser) -> None: + """ + Run async closers in order and raise the first error, if any. + + Notes + ----- + - Used to keep driver-specific close logic readable. + - We intentionally do not stop at first failure, so later resources + still get a chance to close. + """ + first_error: Exception | None = None + + for closer in closers: + try: + await closer() + except Exception as exc: + if first_error is None: + first_error = exc + + if first_error is not None: + raise first_error diff --git a/backend/packages/storage/store/persistence/types.py b/backend/packages/storage/store/persistence/types.py new file mode 100644 index 000000000..715cdb7a2 --- /dev/null +++ b/backend/packages/storage/store/persistence/types.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Callable +from dataclasses import dataclass + +from langgraph.types import Checkpointer +from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker + +AsyncSetup = Callable[[], Awaitable[None]] +AsyncClose = Callable[[], Awaitable[None]] + + +@dataclass(slots=True) +class AppPersistence: + """ + Unified runtime persistence bundle. + """ + checkpointer: Checkpointer + engine: AsyncEngine + session_factory: async_sessionmaker[AsyncSession] + setup: AsyncSetup + aclose: AsyncClose diff --git a/backend/packages/storage/store/repositories/__init__.py b/backend/packages/storage/store/repositories/__init__.py new file mode 100644 index 000000000..4b3f078e7 --- /dev/null +++ b/backend/packages/storage/store/repositories/__init__.py @@ -0,0 +1,51 @@ +from store.repositories.contracts import ( + Feedback, + FeedbackAggregate, + FeedbackCreate, + FeedbackRepositoryProtocol, + Run, + RunCreate, + RunEvent, + RunEventCreate, + RunEventRepositoryProtocol, + RunRepositoryProtocol, + ThreadMeta, + ThreadMetaCreate, + ThreadMetaRepositoryProtocol, + User, + UserCreate, + UserNotFoundError, + UserRepositoryProtocol, +) +from store.repositories.factory import ( + build_feedback_repository, + build_run_event_repository, + build_run_repository, + build_thread_meta_repository, + build_user_repository, +) + +__all__ = [ + "Feedback", + "FeedbackAggregate", + "FeedbackCreate", + "FeedbackRepositoryProtocol", + "Run", + "RunCreate", + "RunEvent", + "RunEventCreate", + "RunEventRepositoryProtocol", + "RunRepositoryProtocol", + "ThreadMeta", + "ThreadMetaCreate", + "ThreadMetaRepositoryProtocol", + "User", + "UserCreate", + "UserNotFoundError", + "UserRepositoryProtocol", + "build_run_repository", + "build_run_event_repository", + "build_thread_meta_repository", + "build_feedback_repository", + "build_user_repository", +] diff --git a/backend/packages/storage/store/repositories/contracts/__init__.py b/backend/packages/storage/store/repositories/contracts/__init__.py new file mode 100644 index 000000000..4876c4a6b --- /dev/null +++ b/backend/packages/storage/store/repositories/contracts/__init__.py @@ -0,0 +1,47 @@ +from store.repositories.contracts.feedback import ( + Feedback, + FeedbackAggregate, + FeedbackCreate, + FeedbackRepositoryProtocol, +) +from store.repositories.contracts.run import ( + Run, + RunCreate, + RunRepositoryProtocol, +) +from store.repositories.contracts.run_event import ( + RunEvent, + RunEventCreate, + RunEventRepositoryProtocol, +) +from store.repositories.contracts.thread_meta import ( + ThreadMeta, + ThreadMetaCreate, + ThreadMetaRepositoryProtocol, +) +from store.repositories.contracts.user import ( + User, + UserCreate, + UserNotFoundError, + UserRepositoryProtocol, +) + +__all__ = [ + "Feedback", + "FeedbackAggregate", + "FeedbackCreate", + "FeedbackRepositoryProtocol", + "Run", + "RunCreate", + "RunEvent", + "RunEventCreate", + "RunEventRepositoryProtocol", + "RunRepositoryProtocol", + "ThreadMeta", + "ThreadMetaCreate", + "ThreadMetaRepositoryProtocol", + "User", + "UserCreate", + "UserNotFoundError", + "UserRepositoryProtocol", +] diff --git a/backend/packages/storage/store/repositories/contracts/feedback.py b/backend/packages/storage/store/repositories/contracts/feedback.py new file mode 100644 index 000000000..0ad0486ea --- /dev/null +++ b/backend/packages/storage/store/repositories/contracts/feedback.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Protocol, TypedDict + +from pydantic import BaseModel, ConfigDict + + +class FeedbackCreate(BaseModel): + model_config = ConfigDict(extra="forbid") + + feedback_id: str + run_id: str + thread_id: str + rating: int + user_id: str | None = None + message_id: str | None = None + comment: str | None = None + + +class Feedback(BaseModel): + model_config = ConfigDict(frozen=True) + + feedback_id: str + run_id: str + thread_id: str + rating: int + user_id: str | None + message_id: str | None + comment: str | None + created_time: datetime + + +class FeedbackAggregate(TypedDict): + run_id: str + total: int + positive: int + negative: int + + +class FeedbackRepositoryProtocol(Protocol): + async def create_feedback(self, data: FeedbackCreate) -> Feedback: ... + async def upsert_feedback(self, data: FeedbackCreate) -> Feedback: ... + async def get_feedback(self, feedback_id: str) -> Feedback | None: ... + async def list_feedback_by_run( + self, + run_id: str, + *, + thread_id: str | None = None, + user_id: str | None = None, + limit: int | None = None, + ) -> list[Feedback]: ... + async def list_feedback_by_thread( + self, + thread_id: str, + *, + user_id: str | None = None, + limit: int | None = None, + ) -> list[Feedback]: ... + async def delete_feedback(self, feedback_id: str) -> bool: ... + async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool: ... + async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate: ... diff --git a/backend/packages/storage/store/repositories/contracts/run.py b/backend/packages/storage/store/repositories/contracts/run.py new file mode 100644 index 000000000..a14995dc9 --- /dev/null +++ b/backend/packages/storage/store/repositories/contracts/run.py @@ -0,0 +1,85 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, Protocol + +from pydantic import BaseModel, ConfigDict, Field + + +class RunCreate(BaseModel): + model_config = ConfigDict(extra="forbid") + + run_id: str + thread_id: str + assistant_id: str | None = None + user_id: str | None = None + status: str = "pending" + model_name: str | None = None + multitask_strategy: str = "reject" + error: str | None = None + follow_up_to_run_id: str | None = None + metadata: dict[str, Any] = Field(default_factory=dict) + kwargs: dict[str, Any] = Field(default_factory=dict) + created_time: datetime | None = None + + +class Run(BaseModel): + model_config = ConfigDict(frozen=True) + + run_id: str + thread_id: str + assistant_id: str | None + user_id: str | None + status: str + model_name: str | None + multitask_strategy: str + error: str | None + follow_up_to_run_id: str | None + metadata: dict[str, Any] + kwargs: dict[str, Any] + total_input_tokens: int + total_output_tokens: int + total_tokens: int + llm_call_count: int + lead_agent_tokens: int + subagent_tokens: int + middleware_tokens: int + message_count: int + first_human_message: str | None + last_ai_message: str | None + created_time: datetime + updated_time: datetime | None + + +class RunRepositoryProtocol(Protocol): + async def create_run(self, data: RunCreate) -> Run: ... + async def get_run(self, run_id: str) -> Run | None: ... + async def list_runs_by_thread( + self, + thread_id: str, + *, + user_id: str | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[Run]: ... + async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None: ... + async def delete_run(self, run_id: str) -> None: ... + async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]: ... + async def update_run_completion( + self, + run_id: str, + *, + status: str, + total_input_tokens: int = 0, + total_output_tokens: int = 0, + total_tokens: int = 0, + llm_call_count: int = 0, + lead_agent_tokens: int = 0, + subagent_tokens: int = 0, + middleware_tokens: int = 0, + message_count: int = 0, + first_human_message: str | None = None, + last_ai_message: str | None = None, + error: str | None = None, + ) -> None: ... + async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: ... diff --git a/backend/packages/storage/store/repositories/contracts/run_event.py b/backend/packages/storage/store/repositories/contracts/run_event.py new file mode 100644 index 000000000..1f0960337 --- /dev/null +++ b/backend/packages/storage/store/repositories/contracts/run_event.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, Protocol + +from pydantic import BaseModel, ConfigDict, Field + + +class RunEventCreate(BaseModel): + model_config = ConfigDict(extra="forbid") + + thread_id: str + run_id: str + user_id: str | None = None + event_type: str + category: str + content: Any = "" + metadata: dict[str, Any] = Field(default_factory=dict) + created_at: datetime | None = None + + +class RunEvent(BaseModel): + model_config = ConfigDict(frozen=True) + + thread_id: str + run_id: str + user_id: str | None + event_type: str + category: str + content: Any + metadata: dict[str, Any] + seq: int + created_at: datetime + + +class RunEventRepositoryProtocol(Protocol): + async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ... + + async def list_messages( + self, + thread_id: str, + *, + limit: int = 50, + before_seq: int | None = None, + after_seq: int | None = None, + user_id: str | None = None, + ) -> list[RunEvent]: ... + + async def list_events( + self, + thread_id: str, + run_id: str, + *, + event_types: list[str] | None = None, + limit: int = 500, + user_id: str | None = None, + ) -> list[RunEvent]: ... + + async def list_messages_by_run( + self, + thread_id: str, + run_id: str, + *, + limit: int = 50, + before_seq: int | None = None, + after_seq: int | None = None, + user_id: str | None = None, + ) -> list[RunEvent]: ... + + async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int: ... + + async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int: ... + + async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int: ... diff --git a/backend/packages/storage/store/repositories/contracts/thread_meta.py b/backend/packages/storage/store/repositories/contracts/thread_meta.py new file mode 100644 index 000000000..de2d82b48 --- /dev/null +++ b/backend/packages/storage/store/repositories/contracts/thread_meta.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any, Protocol + +from pydantic import BaseModel, ConfigDict, Field + + +class ThreadMetaCreate(BaseModel): + model_config = ConfigDict(extra="forbid") + + thread_id: str + assistant_id: str | None = None + user_id: str | None = None + display_name: str | None = None + status: str = "idle" + metadata: dict[str, Any] = Field(default_factory=dict) + + +class ThreadMeta(BaseModel): + model_config = ConfigDict(frozen=True) + + thread_id: str + assistant_id: str | None + user_id: str | None + display_name: str | None + status: str + metadata: dict[str, Any] + created_time: datetime + updated_time: datetime | None + + +class ThreadMetaRepositoryProtocol(Protocol): + async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta: ... + + async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None: ... + + async def update_thread_meta( + self, + thread_id: str, + *, + display_name: str | None = None, + status: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: ... + + async def delete_thread(self, thread_id: str) -> None: ... + + async def search_threads( + self, + *, + metadata: dict[str, Any] | None = None, + status: str | None = None, + user_id: str | None = None, + assistant_id: str | None = None, + limit: int = 100, + offset: int = 0, + ) -> list[ThreadMeta]: ... diff --git a/backend/packages/storage/store/repositories/contracts/user.py b/backend/packages/storage/store/repositories/contracts/user.py new file mode 100644 index 000000000..828ecf0fe --- /dev/null +++ b/backend/packages/storage/store/repositories/contracts/user.py @@ -0,0 +1,56 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Literal, Protocol + +from pydantic import BaseModel, ConfigDict + + +class UserNotFoundError(LookupError): + """Raised when an update targets a user row that no longer exists.""" + + +class UserCreate(BaseModel): + model_config = ConfigDict(extra="forbid") + + id: str + email: str + password_hash: str | None = None + system_role: Literal["admin", "user"] = "user" + created_at: datetime | None = None + oauth_provider: str | None = None + oauth_id: str | None = None + needs_setup: bool = False + token_version: int = 0 + + +class User(BaseModel): + model_config = ConfigDict(frozen=True) + + id: str + email: str + password_hash: str | None + system_role: Literal["admin", "user"] + created_at: datetime + oauth_provider: str | None + oauth_id: str | None + needs_setup: bool + token_version: int + + +class UserRepositoryProtocol(Protocol): + async def create_user(self, data: UserCreate) -> User: ... + + async def get_user_by_id(self, user_id: str) -> User | None: ... + + async def get_user_by_email(self, email: str) -> User | None: ... + + async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: ... + + async def get_first_admin(self) -> User | None: ... + + async def update_user(self, data: User) -> User: ... + + async def count_users(self) -> int: ... + + async def count_admin_users(self) -> int: ... diff --git a/backend/packages/storage/store/repositories/db/__init__.py b/backend/packages/storage/store/repositories/db/__init__.py new file mode 100644 index 000000000..bf3c30509 --- /dev/null +++ b/backend/packages/storage/store/repositories/db/__init__.py @@ -0,0 +1,13 @@ +from store.repositories.db.feedback import DbFeedbackRepository +from store.repositories.db.run import DbRunRepository +from store.repositories.db.run_event import DbRunEventRepository +from store.repositories.db.thread_meta import DbThreadMetaRepository +from store.repositories.db.user import DbUserRepository + +__all__ = [ + "DbFeedbackRepository", + "DbRunRepository", + "DbRunEventRepository", + "DbThreadMetaRepository", + "DbUserRepository", +] diff --git a/backend/packages/storage/store/repositories/db/feedback.py b/backend/packages/storage/store/repositories/db/feedback.py new file mode 100644 index 000000000..67a5c05fa --- /dev/null +++ b/backend/packages/storage/store/repositories/db/feedback.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from datetime import UTC, datetime + +from sqlalchemy import case, delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from store.repositories.contracts.feedback import Feedback, FeedbackAggregate, FeedbackCreate, FeedbackRepositoryProtocol +from store.repositories.models.feedback import Feedback as FeedbackModel + + +def _to_feedback(m: FeedbackModel) -> Feedback: + return Feedback( + feedback_id=m.feedback_id, + run_id=m.run_id, + thread_id=m.thread_id, + rating=m.rating, + user_id=m.user_id, + message_id=m.message_id, + comment=m.comment, + created_time=m.created_time, + ) + + +class DbFeedbackRepository(FeedbackRepositoryProtocol): + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def create_feedback(self, data: FeedbackCreate) -> Feedback: + if data.rating not in (1, -1): + raise ValueError(f"rating must be +1 or -1, got {data.rating}") + model = FeedbackModel( + feedback_id=data.feedback_id, + run_id=data.run_id, + thread_id=data.thread_id, + rating=data.rating, + user_id=data.user_id, + message_id=data.message_id, + comment=data.comment, + ) + self._session.add(model) + await self._session.flush() + await self._session.refresh(model) + return _to_feedback(model) + + async def upsert_feedback(self, data: FeedbackCreate) -> Feedback: + if data.rating not in (1, -1): + raise ValueError(f"rating must be +1 or -1, got {data.rating}") + + result = await self._session.execute( + select(FeedbackModel).where( + FeedbackModel.thread_id == data.thread_id, + FeedbackModel.run_id == data.run_id, + FeedbackModel.user_id == data.user_id, + ) + ) + model = result.scalar_one_or_none() + if model is None: + return await self.create_feedback(data) + + model.rating = data.rating + model.message_id = data.message_id + model.comment = data.comment + model.created_time = datetime.now(UTC) + await self._session.flush() + await self._session.refresh(model) + return _to_feedback(model) + + async def get_feedback(self, feedback_id: str) -> Feedback | None: + result = await self._session.execute( + select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id) + ) + model = result.scalar_one_or_none() + return _to_feedback(model) if model else None + + async def list_feedback_by_run( + self, + run_id: str, + *, + thread_id: str | None = None, + user_id: str | None = None, + limit: int | None = None, + ) -> list[Feedback]: + stmt = select(FeedbackModel).where(FeedbackModel.run_id == run_id) + if thread_id is not None: + stmt = stmt.where(FeedbackModel.thread_id == thread_id) + if user_id is not None: + stmt = stmt.where(FeedbackModel.user_id == user_id) + stmt = stmt.order_by(FeedbackModel.created_time.desc()) + if limit is not None: + stmt = stmt.limit(limit) + result = await self._session.execute(stmt) + return [_to_feedback(m) for m in result.scalars().all()] + + async def list_feedback_by_thread( + self, + thread_id: str, + *, + user_id: str | None = None, + limit: int | None = None, + ) -> list[Feedback]: + stmt = select(FeedbackModel).where(FeedbackModel.thread_id == thread_id) + if user_id is not None: + stmt = stmt.where(FeedbackModel.user_id == user_id) + stmt = stmt.order_by(FeedbackModel.created_time.desc()) + if limit is not None: + stmt = stmt.limit(limit) + result = await self._session.execute(stmt) + return [_to_feedback(m) for m in result.scalars().all()] + + async def delete_feedback(self, feedback_id: str) -> bool: + existing = await self.get_feedback(feedback_id) + if existing is None: + return False + await self._session.execute( + delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id) + ) + return True + + async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool: + stmt = select(FeedbackModel).where( + FeedbackModel.thread_id == thread_id, + FeedbackModel.run_id == run_id, + ) + if user_id is not None: + stmt = stmt.where(FeedbackModel.user_id == user_id) + result = await self._session.execute(stmt) + model = result.scalar_one_or_none() + if model is None: + return False + await self._session.delete(model) + return True + + async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate: + stmt = select( + func.count().label("total"), + func.coalesce(func.sum(case((FeedbackModel.rating == 1, 1), else_=0)), 0).label("positive"), + func.coalesce(func.sum(case((FeedbackModel.rating == -1, 1), else_=0)), 0).label("negative"), + ).where(FeedbackModel.thread_id == thread_id, FeedbackModel.run_id == run_id) + row = (await self._session.execute(stmt)).one() + return { + "run_id": run_id, + "total": int(row.total), + "positive": int(row.positive), + "negative": int(row.negative), + } diff --git a/backend/packages/storage/store/repositories/db/run.py b/backend/packages/storage/store/repositories/db/run.py new file mode 100644 index 000000000..93e6e1d95 --- /dev/null +++ b/backend/packages/storage/store/repositories/db/run.py @@ -0,0 +1,196 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from sqlalchemy import delete, func, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from store.repositories.contracts.run import Run, RunCreate, RunRepositoryProtocol +from store.repositories.models.run import Run as RunModel + + +def _to_run(m: RunModel) -> Run: + return Run( + run_id=m.run_id, + thread_id=m.thread_id, + assistant_id=m.assistant_id, + user_id=m.user_id, + status=m.status, + model_name=m.model_name, + multitask_strategy=m.multitask_strategy, + error=m.error, + follow_up_to_run_id=m.follow_up_to_run_id, + metadata=dict(m.meta or {}), + kwargs=dict(m.kwargs or {}), + total_input_tokens=m.total_input_tokens, + total_output_tokens=m.total_output_tokens, + total_tokens=m.total_tokens, + llm_call_count=m.llm_call_count, + lead_agent_tokens=m.lead_agent_tokens, + subagent_tokens=m.subagent_tokens, + middleware_tokens=m.middleware_tokens, + message_count=m.message_count, + first_human_message=m.first_human_message, + last_ai_message=m.last_ai_message, + created_time=m.created_time, + updated_time=m.updated_time, + ) + + +class DbRunRepository(RunRepositoryProtocol): + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def create_run(self, data: RunCreate) -> Run: + model = RunModel( + run_id=data.run_id, + thread_id=data.thread_id, + assistant_id=data.assistant_id, + user_id=data.user_id, + status=data.status, + model_name=data.model_name, + multitask_strategy=data.multitask_strategy, + error=data.error, + follow_up_to_run_id=data.follow_up_to_run_id, + meta=dict(data.metadata), + kwargs=dict(data.kwargs), + ) + if data.created_time is not None: + model.created_time = data.created_time + self._session.add(model) + await self._session.flush() + await self._session.refresh(model) + return _to_run(model) + + async def get_run(self, run_id: str) -> Run | None: + result = await self._session.execute( + select(RunModel).where(RunModel.run_id == run_id) + ) + model = result.scalar_one_or_none() + return _to_run(model) if model else None + + async def list_runs_by_thread( + self, + thread_id: str, + *, + user_id: str | None = None, + limit: int = 50, + offset: int = 0, + ) -> list[Run]: + stmt = select(RunModel).where(RunModel.thread_id == thread_id) + if user_id is not None: + stmt = stmt.where(RunModel.user_id == user_id) + stmt = stmt.order_by(RunModel.created_time.desc()).limit(limit).offset(offset) + result = await self._session.execute(stmt) + return [_to_run(m) for m in result.scalars().all()] + + async def update_run_status( + self, run_id: str, status: str, *, error: str | None = None + ) -> None: + values: dict = {"status": status} + if error is not None: + values["error"] = error + await self._session.execute( + update(RunModel).where(RunModel.run_id == run_id).values(**values) + ) + + async def delete_run(self, run_id: str) -> None: + await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id)) + + async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]: + if before is None: + before_dt = datetime.now().astimezone() + elif isinstance(before, datetime): + before_dt = before + else: + before_dt = datetime.fromisoformat(before) + + result = await self._session.execute( + select(RunModel) + .where(RunModel.status == "pending", RunModel.created_time <= before_dt) + .order_by(RunModel.created_time.asc()) + ) + return [_to_run(m) for m in result.scalars().all()] + + async def update_run_completion( + self, + run_id: str, + *, + status: str, + total_input_tokens: int = 0, + total_output_tokens: int = 0, + total_tokens: int = 0, + llm_call_count: int = 0, + lead_agent_tokens: int = 0, + subagent_tokens: int = 0, + middleware_tokens: int = 0, + message_count: int = 0, + first_human_message: str | None = None, + last_ai_message: str | None = None, + error: str | None = None, + ) -> None: + values = { + "status": status, + "total_input_tokens": total_input_tokens, + "total_output_tokens": total_output_tokens, + "total_tokens": total_tokens, + "llm_call_count": llm_call_count, + "lead_agent_tokens": lead_agent_tokens, + "subagent_tokens": subagent_tokens, + "middleware_tokens": middleware_tokens, + "message_count": message_count, + } + if first_human_message is not None: + values["first_human_message"] = first_human_message[:2000] + if last_ai_message is not None: + values["last_ai_message"] = last_ai_message[:2000] + if error is not None: + values["error"] = error + await self._session.execute( + update(RunModel).where(RunModel.run_id == run_id).values(**values) + ) + + async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: + completed = RunModel.status.in_(("success", "error")) + stmt = ( + select( + func.coalesce(RunModel.model_name, "unknown").label("model"), + func.count().label("runs"), + func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"), + func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"), + func.coalesce(func.sum(RunModel.total_output_tokens), 0).label("total_output_tokens"), + func.coalesce(func.sum(RunModel.lead_agent_tokens), 0).label("lead_agent"), + func.coalesce(func.sum(RunModel.subagent_tokens), 0).label("subagent"), + func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"), + ) + .where(RunModel.thread_id == thread_id, completed) + .group_by(func.coalesce(RunModel.model_name, "unknown")) + ) + + rows = (await self._session.execute(stmt)).all() + total_tokens = total_input = total_output = total_runs = 0 + lead_agent = subagent = middleware = 0 + by_model: dict[str, dict] = {} + for row in rows: + by_model[row.model] = {"tokens": row.total_tokens, "runs": row.runs} + total_tokens += row.total_tokens + total_input += row.total_input_tokens + total_output += row.total_output_tokens + total_runs += row.runs + lead_agent += row.lead_agent + subagent += row.subagent + middleware += row.middleware + + return { + "total_tokens": total_tokens, + "total_input_tokens": total_input, + "total_output_tokens": total_output, + "total_runs": total_runs, + "by_model": by_model, + "by_caller": { + "lead_agent": lead_agent, + "subagent": subagent, + "middleware": middleware, + }, + } diff --git a/backend/packages/storage/store/repositories/db/run_event.py b/backend/packages/storage/store/repositories/db/run_event.py new file mode 100644 index 000000000..f18d17b77 --- /dev/null +++ b/backend/packages/storage/store/repositories/db/run_event.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +import json +from typing import Any + +from sqlalchemy import delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol +from store.repositories.models.run_event import RunEvent as RunEventModel + + +def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]: + if not isinstance(content, str): + next_metadata = {**metadata, "content_is_json": True} + if isinstance(content, dict): + next_metadata["content_is_dict"] = True + return json.dumps(content, default=str, ensure_ascii=False), next_metadata + return content, metadata + + +def _deserialize_content(content: str, metadata: dict[str, Any]) -> Any: + if not (metadata.get("content_is_json") or metadata.get("content_is_dict")): + return content + try: + return json.loads(content) + except json.JSONDecodeError: + return content + + +def _to_run_event(model: RunEventModel) -> RunEvent: + raw_metadata = dict(model.meta or {}) + metadata = {key: value for key, value in raw_metadata.items() if key != "content_is_dict"} + return RunEvent( + thread_id=model.thread_id, + run_id=model.run_id, + user_id=model.user_id, + event_type=model.event_type, + category=model.category, + content=_deserialize_content(model.content, raw_metadata), + metadata=metadata, + seq=model.seq, + created_at=model.created_at, + ) + + +class DbRunEventRepository(RunEventRepositoryProtocol): + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: + if not events: + return [] + + thread_ids = {event.thread_id for event in events} + seq_by_thread: dict[str, int] = {} + for thread_id in thread_ids: + max_seq = await self._session.scalar( + select(func.max(RunEventModel.seq)) + .where(RunEventModel.thread_id == thread_id) + .with_for_update() + ) + seq_by_thread[thread_id] = max_seq or 0 + + rows: list[RunEventModel] = [] + + for event in events: + seq_by_thread[event.thread_id] += 1 + content, metadata = _serialize_content(event.content, dict(event.metadata)) + row = RunEventModel( + thread_id=event.thread_id, + run_id=event.run_id, + user_id=event.user_id, + seq=seq_by_thread[event.thread_id], + event_type=event.event_type, + category=event.category, + content=content, + meta=metadata, + ) + if event.created_at is not None: + row.created_at = event.created_at + self._session.add(row) + rows.append(row) + + await self._session.flush() + return [_to_run_event(row) for row in rows] + + async def list_messages( + self, + thread_id: str, + *, + limit: int = 50, + before_seq: int | None = None, + after_seq: int | None = None, + user_id: str | None = None, + ) -> list[RunEvent]: + stmt = select(RunEventModel).where( + RunEventModel.thread_id == thread_id, + RunEventModel.category == "message", + ) + if user_id is not None: + stmt = stmt.where(RunEventModel.user_id == user_id) + if before_seq is not None: + stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit) + result = await self._session.execute(stmt) + return list(reversed([_to_run_event(row) for row in result.scalars().all()])) + if after_seq is not None: + stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit) + result = await self._session.execute(stmt) + return [_to_run_event(row) for row in result.scalars().all()] + + stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit) + result = await self._session.execute(stmt) + return list(reversed([_to_run_event(row) for row in result.scalars().all()])) + + async def list_events( + self, + thread_id: str, + run_id: str, + *, + event_types: list[str] | None = None, + limit: int = 500, + user_id: str | None = None, + ) -> list[RunEvent]: + stmt = select(RunEventModel).where( + RunEventModel.thread_id == thread_id, + RunEventModel.run_id == run_id, + ) + if user_id is not None: + stmt = stmt.where(RunEventModel.user_id == user_id) + if event_types is not None: + stmt = stmt.where(RunEventModel.event_type.in_(event_types)) + stmt = stmt.order_by(RunEventModel.seq.asc()).limit(limit) + result = await self._session.execute(stmt) + return [_to_run_event(row) for row in result.scalars().all()] + + async def list_messages_by_run( + self, + thread_id: str, + run_id: str, + *, + limit: int = 50, + before_seq: int | None = None, + after_seq: int | None = None, + user_id: str | None = None, + ) -> list[RunEvent]: + stmt = ( + select(RunEventModel) + .where( + RunEventModel.thread_id == thread_id, + RunEventModel.run_id == run_id, + RunEventModel.category == "message", + ) + ) + if user_id is not None: + stmt = stmt.where(RunEventModel.user_id == user_id) + if before_seq is not None: + stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit) + result = await self._session.execute(stmt) + return list(reversed([_to_run_event(row) for row in result.scalars().all()])) + if after_seq is not None: + stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit) + result = await self._session.execute(stmt) + return [_to_run_event(row) for row in result.scalars().all()] + + stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit) + result = await self._session.execute(stmt) + return list(reversed([_to_run_event(row) for row in result.scalars().all()])) + + async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int: + stmt = ( + select(func.count()) + .select_from(RunEventModel) + .where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message") + ) + if user_id is not None: + stmt = stmt.where(RunEventModel.user_id == user_id) + count = await self._session.scalar(stmt) + return int(count or 0) + + async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int: + conditions = [RunEventModel.thread_id == thread_id] + if user_id is not None: + conditions.append(RunEventModel.user_id == user_id) + count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions)) + await self._session.execute(delete(RunEventModel).where(*conditions)) + return int(count or 0) + + async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int: + conditions = [RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id] + if user_id is not None: + conditions.append(RunEventModel.user_id == user_id) + count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions)) + await self._session.execute(delete(RunEventModel).where(*conditions)) + return int(count or 0) diff --git a/backend/packages/storage/store/repositories/db/thread_meta.py b/backend/packages/storage/store/repositories/db/thread_meta.py new file mode 100644 index 000000000..f9fcf74d4 --- /dev/null +++ b/backend/packages/storage/store/repositories/db/thread_meta.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +from typing import Any + +from sqlalchemy import delete, select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from store.repositories.contracts.thread_meta import ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol +from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel + + +def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta: + return ThreadMeta( + thread_id=m.thread_id, + assistant_id=m.assistant_id, + user_id=m.user_id, + display_name=m.display_name, + status=m.status, + metadata=dict(m.meta or {}), + created_time=m.created_time, + updated_time=m.updated_time, + ) + + +class DbThreadMetaRepository(ThreadMetaRepositoryProtocol): + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta: + model = ThreadMetaModel( + thread_id=data.thread_id, + assistant_id=data.assistant_id, + user_id=data.user_id, + display_name=data.display_name, + status=data.status, + meta=dict(data.metadata), + ) + self._session.add(model) + await self._session.flush() + await self._session.refresh(model) + return _to_thread_meta(model) + + async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None: + result = await self._session.execute(select(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id)) + model = result.scalar_one_or_none() + return _to_thread_meta(model) if model else None + + async def update_thread_meta( + self, + thread_id: str, + *, + display_name: str | None = None, + status: str | None = None, + metadata: dict[str, Any] | None = None, + ) -> None: + values: dict = {} + if display_name is not None: + values["display_name"] = display_name + if status is not None: + values["status"] = status + if metadata is not None: + values["meta"] = dict(metadata) + if not values: + return + await self._session.execute( + update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values)) + + async def delete_thread(self, thread_id: str) -> None: + await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id)) + + async def search_threads( + self, + *, + metadata: dict[str, Any] | None = None, + status: str | None = None, + user_id: str | None = None, + assistant_id: str | None = None, + limit: int = 100, + offset: int = 0, + ) -> list[ThreadMeta]: + stmt = select(ThreadMetaModel) + + if status is not None: + stmt = stmt.where(ThreadMetaModel.status == status) + if user_id is not None: + stmt = stmt.where(ThreadMetaModel.user_id == user_id) + if assistant_id is not None: + stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id) + if metadata: + for key, value in metadata.items(): + stmt = stmt.where(ThreadMetaModel.meta[key].as_string() == str(value)) + + stmt = stmt.order_by(ThreadMetaModel.created_time.desc()) + stmt = stmt.limit(limit).offset(offset) + + result = await self._session.execute(stmt) + return [_to_thread_meta(m) for m in result.scalars().all()] diff --git a/backend/packages/storage/store/repositories/db/user.py b/backend/packages/storage/store/repositories/db/user.py new file mode 100644 index 000000000..1b2417420 --- /dev/null +++ b/backend/packages/storage/store/repositories/db/user.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +from sqlalchemy import func, select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from store.repositories.contracts.user import User, UserCreate, UserNotFoundError, UserRepositoryProtocol +from store.repositories.models.user import User as UserModel + + +def _to_user(model: UserModel) -> User: + return User( + id=model.id, + email=model.email, + password_hash=model.password_hash, + system_role=model.system_role, # type: ignore[arg-type] + created_at=model.created_at, + oauth_provider=model.oauth_provider, + oauth_id=model.oauth_id, + needs_setup=model.needs_setup, + token_version=model.token_version, + ) + + +class DbUserRepository(UserRepositoryProtocol): + def __init__(self, session: AsyncSession) -> None: + self._session = session + + async def create_user(self, data: UserCreate) -> User: + model = UserModel( + id=data.id, + email=data.email, + system_role=data.system_role, + password_hash=data.password_hash, + oauth_provider=data.oauth_provider, + oauth_id=data.oauth_id, + needs_setup=data.needs_setup, + token_version=data.token_version, + ) + if data.created_at is not None: + model.created_at = data.created_at + self._session.add(model) + try: + await self._session.flush() + except IntegrityError as exc: + await self._session.rollback() + raise ValueError(f"Email already registered: {data.email}") from exc + await self._session.refresh(model) + return _to_user(model) + + async def get_user_by_id(self, user_id: str) -> User | None: + model = await self._session.get(UserModel, user_id) + return _to_user(model) if model is not None else None + + async def get_user_by_email(self, email: str) -> User | None: + result = await self._session.execute(select(UserModel).where(UserModel.email == email)) + model = result.scalar_one_or_none() + return _to_user(model) if model is not None else None + + async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: + result = await self._session.execute( + select(UserModel).where( + UserModel.oauth_provider == provider, + UserModel.oauth_id == oauth_id, + ) + ) + model = result.scalar_one_or_none() + return _to_user(model) if model is not None else None + + async def get_first_admin(self) -> User | None: + result = await self._session.execute(select(UserModel).where(UserModel.system_role == "admin").limit(1)) + model = result.scalar_one_or_none() + return _to_user(model) if model is not None else None + + async def update_user(self, data: User) -> User: + model = await self._session.get(UserModel, data.id) + if model is None: + raise UserNotFoundError(f"User {data.id} no longer exists") + + model.email = data.email + model.password_hash = data.password_hash + model.system_role = data.system_role + model.oauth_provider = data.oauth_provider + model.oauth_id = data.oauth_id + model.needs_setup = data.needs_setup + model.token_version = data.token_version + + await self._session.flush() + await self._session.refresh(model) + return _to_user(model) + + async def count_users(self) -> int: + count = await self._session.scalar(select(func.count()).select_from(UserModel)) + return int(count or 0) + + async def count_admin_users(self) -> int: + count = await self._session.scalar(select(func.count()).select_from(UserModel).where(UserModel.system_role == "admin")) + return int(count or 0) diff --git a/backend/packages/storage/store/repositories/factory.py b/backend/packages/storage/store/repositories/factory.py new file mode 100644 index 000000000..b99d4e0b3 --- /dev/null +++ b/backend/packages/storage/store/repositories/factory.py @@ -0,0 +1,36 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from store.repositories import ( + FeedbackRepositoryProtocol, + RunEventRepositoryProtocol, + RunRepositoryProtocol, + ThreadMetaRepositoryProtocol, + UserRepositoryProtocol, +) +from store.repositories.db import ( + DbFeedbackRepository, + DbRunEventRepository, + DbRunRepository, + DbThreadMetaRepository, + DbUserRepository, +) + + +def build_thread_meta_repository(session: AsyncSession) -> ThreadMetaRepositoryProtocol: + return DbThreadMetaRepository(session) + + +def build_run_repository(session: AsyncSession) -> RunRepositoryProtocol: + return DbRunRepository(session) + + +def build_feedback_repository(session: AsyncSession) -> FeedbackRepositoryProtocol: + return DbFeedbackRepository(session) + + +def build_run_event_repository(session: AsyncSession) -> RunEventRepositoryProtocol: + return DbRunEventRepository(session) + + +def build_user_repository(session: AsyncSession) -> UserRepositoryProtocol: + return DbUserRepository(session) diff --git a/backend/packages/storage/store/repositories/models/__init__.py b/backend/packages/storage/store/repositories/models/__init__.py new file mode 100644 index 000000000..42c1f6d8d --- /dev/null +++ b/backend/packages/storage/store/repositories/models/__init__.py @@ -0,0 +1,7 @@ +from store.repositories.models.feedback import Feedback +from store.repositories.models.run import Run +from store.repositories.models.run_event import RunEvent +from store.repositories.models.thread_meta import ThreadMeta +from store.repositories.models.user import User + +__all__ = ["Feedback", "Run", "RunEvent", "ThreadMeta", "User"] diff --git a/backend/packages/storage/store/repositories/models/feedback.py b/backend/packages/storage/store/repositories/models/feedback.py new file mode 100644 index 000000000..581a91bbc --- /dev/null +++ b/backend/packages/storage/store/repositories/models/feedback.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from datetime import datetime + +from sqlalchemy import Integer, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from store.persistence.base_model import DataClassBase, TimeZone, UniversalText +from store.utils import get_timezone + +_tz = get_timezone() + + +class Feedback(DataClassBase): + """Feedback table (create-only, no updated_time).""" + + __tablename__ = "feedback" + __table_args__ = ( + UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"), + {"comment": "Feedback table."}, + ) + + feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True) + run_id: Mapped[str] = mapped_column(String(64), index=True) + thread_id: Mapped[str] = mapped_column(String(64), index=True) + rating: Mapped[int] = mapped_column(Integer) + + user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True) + message_id: Mapped[str | None] = mapped_column(String(64), default=None) + comment: Mapped[str | None] = mapped_column(UniversalText, default=None) + + created_time: Mapped[datetime] = mapped_column( + "created_at", + TimeZone, + init=False, + default_factory=_tz.now, + sort_order=999, + comment="Created at", + ) diff --git a/backend/packages/storage/store/repositories/models/run.py b/backend/packages/storage/store/repositories/models/run.py new file mode 100644 index 000000000..dd0f93b88 --- /dev/null +++ b/backend/packages/storage/store/repositories/models/run.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from sqlalchemy import JSON, Index, Integer, String +from sqlalchemy.orm import Mapped, mapped_column + +from store.persistence.base_model import DataClassBase, TimeZone, UniversalText +from store.utils import get_timezone + +_tz = get_timezone() + + +class Run(DataClassBase): + """Run metadata table.""" + + __tablename__ = "runs" + __table_args__ = ( + Index("ix_runs_thread_status", "thread_id", "status"), + {"comment": "Run metadata table."}, + ) + + run_id: Mapped[str] = mapped_column(String(64), primary_key=True) + thread_id: Mapped[str] = mapped_column(String(64), index=True) + + assistant_id: Mapped[str | None] = mapped_column(String(128), default=None) + user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True) + status: Mapped[str] = mapped_column(String(20), default="pending", index=True) + model_name: Mapped[str | None] = mapped_column(String(128), default=None) + multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject") + error: Mapped[str | None] = mapped_column(UniversalText, default=None) + follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64), default=None) + + meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict) + kwargs: Mapped[dict[str, Any]] = mapped_column("kwargs_json", JSON, default_factory=dict) + + total_input_tokens: Mapped[int] = mapped_column(Integer, default=0) + total_output_tokens: Mapped[int] = mapped_column(Integer, default=0) + total_tokens: Mapped[int] = mapped_column(Integer, default=0) + llm_call_count: Mapped[int] = mapped_column(Integer, default=0) + lead_agent_tokens: Mapped[int] = mapped_column(Integer, default=0) + subagent_tokens: Mapped[int] = mapped_column(Integer, default=0) + middleware_tokens: Mapped[int] = mapped_column(Integer, default=0) + + message_count: Mapped[int] = mapped_column(Integer, default=0) + first_human_message: Mapped[str | None] = mapped_column(UniversalText, default=None) + last_ai_message: Mapped[str | None] = mapped_column(UniversalText, default=None) + + created_time: Mapped[datetime] = mapped_column( + "created_at", + TimeZone, + init=False, + default_factory=_tz.now, + sort_order=999, + comment="Created at", + ) + updated_time: Mapped[datetime | None] = mapped_column( + "updated_at", + TimeZone, + init=False, + default=None, + onupdate=_tz.now, + sort_order=999, + comment="Updated at", + ) diff --git a/backend/packages/storage/store/repositories/models/run_event.py b/backend/packages/storage/store/repositories/models/run_event.py new file mode 100644 index 000000000..b07665d1c --- /dev/null +++ b/backend/packages/storage/store/repositories/models/run_event.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from sqlalchemy import JSON, Index, Integer, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column + +from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key +from store.utils import get_timezone + +_tz = get_timezone() + + +class RunEvent(DataClassBase): + """Run event table.""" + + __tablename__ = "run_events" + __table_args__ = ( + UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"), + Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"), + Index("ix_events_run", "thread_id", "run_id", "seq"), + {"comment": "Run event table."}, + ) + + id: Mapped[id_key] = mapped_column(init=False) + + thread_id: Mapped[str] = mapped_column(String(64), index=True) + run_id: Mapped[str] = mapped_column(String(64), index=True) + event_type: Mapped[str] = mapped_column(String(32), index=True) + category: Mapped[str] = mapped_column(String(16), index=True) + + user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True) + seq: Mapped[int] = mapped_column(Integer, default=0, index=True) + content: Mapped[str] = mapped_column(UniversalText, default="") + meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict) + created_at: Mapped[datetime] = mapped_column( + TimeZone, + init=False, + default_factory=_tz.now, + sort_order=999, + comment="Event timestamp", + ) diff --git a/backend/packages/storage/store/repositories/models/thread_meta.py b/backend/packages/storage/store/repositories/models/thread_meta.py new file mode 100644 index 000000000..ce3e70f27 --- /dev/null +++ b/backend/packages/storage/store/repositories/models/thread_meta.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from sqlalchemy import JSON, String +from sqlalchemy.orm import Mapped, mapped_column + +from store.persistence.base_model import DataClassBase, TimeZone +from store.utils import get_timezone + +_tz = get_timezone() + + +class ThreadMeta(DataClassBase): + """Thread metadata table.""" + + __tablename__ = "threads_meta" + __table_args__ = {"comment": "Thread metadata table."} + + thread_id: Mapped[str] = mapped_column(String(64), primary_key=True) + + assistant_id: Mapped[str | None] = mapped_column(String(128), default=None, index=True) + user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True) + display_name: Mapped[str | None] = mapped_column(String(256), default=None) + status: Mapped[str] = mapped_column(String(20), default="idle", index=True) + + meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict) + + created_time: Mapped[datetime] = mapped_column( + "created_at", + TimeZone, + init=False, + default_factory=_tz.now, + sort_order=999, + comment="Created at", + ) + updated_time: Mapped[datetime | None] = mapped_column( + "updated_at", + TimeZone, + init=False, + default=None, + onupdate=_tz.now, + sort_order=999, + comment="Updated at", + ) diff --git a/backend/packages/storage/store/repositories/models/user.py b/backend/packages/storage/store/repositories/models/user.py new file mode 100644 index 000000000..a017ec47e --- /dev/null +++ b/backend/packages/storage/store/repositories/models/user.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from datetime import datetime + +from sqlalchemy import Boolean, Index, String, text +from sqlalchemy.orm import Mapped, mapped_column + +from store.persistence.base_model import DataClassBase, TimeZone +from store.utils import get_timezone + +_tz = get_timezone() + + +class User(DataClassBase): + """User account table.""" + + __tablename__ = "users" + __table_args__ = ( + Index( + "idx_users_oauth_identity", + "oauth_provider", + "oauth_id", + unique=True, + sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"), + ), + {"comment": "User account table."}, + ) + + id: Mapped[str] = mapped_column(String(36), primary_key=True) + email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True) + system_role: Mapped[str] = mapped_column(String(16), default="user") + + password_hash: Mapped[str | None] = mapped_column(String(128), default=None) + oauth_provider: Mapped[str | None] = mapped_column(String(32), default=None) + oauth_id: Mapped[str | None] = mapped_column(String(128), default=None) + needs_setup: Mapped[bool] = mapped_column(Boolean, default=False) + token_version: Mapped[int] = mapped_column(default=0) + + created_at: Mapped[datetime] = mapped_column( + TimeZone, + init=False, + default_factory=_tz.now, + sort_order=999, + comment="Created at", + ) diff --git a/backend/packages/storage/store/utils/__init__.py b/backend/packages/storage/store/utils/__init__.py new file mode 100644 index 000000000..1ee6c9df9 --- /dev/null +++ b/backend/packages/storage/store/utils/__init__.py @@ -0,0 +1,3 @@ +from .timezone import get_timezone + +__all__ = ["get_timezone"] \ No newline at end of file diff --git a/backend/packages/storage/store/utils/timezone.py b/backend/packages/storage/store/utils/timezone.py new file mode 100644 index 000000000..ac9ebcf68 --- /dev/null +++ b/backend/packages/storage/store/utils/timezone.py @@ -0,0 +1,51 @@ +import zoneinfo +from datetime import UTC, datetime + +from store.config.app_config import get_app_config + +# IANA identifiers that map to UTC — see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones +_UTC_IDENTIFIERS = frozenset({"Etc/UCT", "Etc/Universal", "Etc/UTC", "Etc/Zulu", "UCT", "Universal", "UTC", "Zulu"}) + + +class TimeZone: + def __init__(self) -> None: + app_config = get_app_config() + if app_config.timezone in _UTC_IDENTIFIERS: + self.tz_info = UTC + else: + self.tz_info = zoneinfo.ZoneInfo(app_config.timezone) + + def now(self) -> datetime: + """Return the current time in the configured timezone.""" + return datetime.now(self.tz_info) + + def from_datetime(self, t: datetime) -> datetime: + """Convert a datetime to the configured timezone.""" + return t.astimezone(self.tz_info) + + def from_str(self, t_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> datetime: + """Parse a time string and attach the configured timezone.""" + return datetime.strptime(t_str, format_str).replace(tzinfo=self.tz_info) + + @staticmethod + def to_str(t: datetime, format_str: str = "%Y-%m-%d %H:%M:%S") -> str: + """Format a datetime to string.""" + return t.strftime(format_str) + + @staticmethod + def to_utc(t: datetime | int) -> datetime: + """Convert a datetime or Unix timestamp to UTC.""" + if isinstance(t, datetime): + return t.astimezone(UTC) + return datetime.fromtimestamp(t, tz=UTC) + + +_timezone = None + + +def get_timezone() -> TimeZone: + """Return the global TimeZone singleton (lazy-initialized).""" + global _timezone + if _timezone is None: + _timezone = TimeZone() + return _timezone diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 6d2edb0bb..6722b255f 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -6,6 +6,7 @@ readme = "README.md" requires-python = ">=3.12" dependencies = [ "deerflow-harness", + "deerflow-storage", "fastapi>=0.115.0", "httpx>=0.28.0", "python-multipart>=0.0.27", @@ -24,7 +25,7 @@ dependencies = [ ] [project.optional-dependencies] -postgres = ["deerflow-harness[postgres]"] +postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"] [dependency-groups] dev = [ @@ -43,7 +44,8 @@ markers = [ index-url = "https://pypi.org/simple" [tool.uv.workspace] -members = ["packages/harness"] +members = ["packages/harness", "packages/storage"] [tool.uv.sources] deerflow-harness = { workspace = true } +deerflow-storage = { workspace = true } diff --git a/backend/tests/test_storage_persistence_config.py b/backend/tests/test_storage_persistence_config.py new file mode 100644 index 000000000..17b290296 --- /dev/null +++ b/backend/tests/test_storage_persistence_config.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import os +from pathlib import Path +from types import SimpleNamespace + +import pytest + +os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml")) + +from store.config.storage_config import StorageConfig +from store.persistence.factory import _create_database_url, storage_config_from_database_config + + +def test_database_sqlite_config_maps_to_storage_config(tmp_path): + database = SimpleNamespace( + backend="sqlite", + sqlite_dir=str(tmp_path), + echo_sql=True, + pool_size=9, + ) + + storage = storage_config_from_database_config(database) + + assert storage == StorageConfig( + driver="sqlite", + sqlite_dir=str(tmp_path), + echo_sql=True, + pool_size=9, + ) + assert storage.sqlite_storage_path == str(tmp_path / "deerflow.db") + + +def test_database_memory_config_is_not_a_storage_backend(): + database = SimpleNamespace(backend="memory") + + with pytest.raises(ValueError, match="Unsupported database backend"): + storage_config_from_database_config(database) + + +def test_database_postgres_config_preserves_url_and_pool_options(): + database = SimpleNamespace( + backend="postgres", + postgres_url="postgresql://user:pass@db.example:5544/deerflow", + echo_sql=True, + pool_size=11, + ) + + storage = storage_config_from_database_config(database) + url = _create_database_url(storage) + + assert storage.driver == "postgres" + assert storage.database_url == "postgresql://user:pass@db.example:5544/deerflow" + assert storage.username == "user" + assert storage.password == "pass" + assert storage.host == "db.example" + assert storage.port == 5544 + assert storage.db_name == "deerflow" + assert storage.echo_sql is True + assert storage.pool_size == 11 + assert url.drivername == "postgresql+asyncpg" + assert url.database == "deerflow" + + +def test_database_postgres_requires_url(): + database = SimpleNamespace(backend="postgres", postgres_url="") + + with pytest.raises(ValueError, match="database.postgres_url is required"): + storage_config_from_database_config(database) + + +def test_unsupported_database_backend_rejected(): + database = SimpleNamespace(backend="oracle") + + with pytest.raises(ValueError, match="Unsupported database backend"): + storage_config_from_database_config(database) diff --git a/backend/tests/test_storage_persistence_sqlite.py b/backend/tests/test_storage_persistence_sqlite.py new file mode 100644 index 000000000..2d427b31b --- /dev/null +++ b/backend/tests/test_storage_persistence_sqlite.py @@ -0,0 +1,58 @@ +from __future__ import annotations + +import asyncio +import os +from pathlib import Path +from types import SimpleNamespace +from uuid import uuid4 + +os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml")) + +from sqlalchemy import inspect +from store.persistence import create_persistence_from_database_config +from store.repositories import UserCreate, build_user_repository + + +def test_sqlite_persistence_from_database_config_creates_storage_tables(tmp_path): + async def run() -> None: + persistence = await create_persistence_from_database_config( + SimpleNamespace( + backend="sqlite", + sqlite_dir=str(tmp_path), + echo_sql=False, + pool_size=5, + ) + ) + assert persistence is not None + try: + await persistence.setup() + + async with persistence.engine.connect() as conn: + tables = await conn.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names())) + + assert { + "users", + "runs", + "run_events", + "threads_meta", + "feedback", + }.issubset(tables) + + async with persistence.session_factory() as session: + repo = build_user_repository(session) + user = await repo.create_user( + UserCreate( + id=str(uuid4()), + email="storage-user@example.com", + password_hash="hash", + ) + ) + await session.commit() + + async with persistence.session_factory() as session: + repo = build_user_repository(session) + assert await repo.get_user_by_id(user.id) == user + finally: + await persistence.aclose() + + asyncio.run(run()) diff --git a/backend/tests/test_storage_repositories.py b/backend/tests/test_storage_repositories.py new file mode 100644 index 000000000..d04563e55 --- /dev/null +++ b/backend/tests/test_storage_repositories.py @@ -0,0 +1,312 @@ +from __future__ import annotations + +import os +from datetime import UTC, datetime, timedelta +from pathlib import Path +from types import SimpleNamespace + +import pytest + +os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml")) + +from store.persistence import create_persistence_from_database_config +from store.repositories import ( + FeedbackCreate, + RunCreate, + RunEventCreate, + ThreadMetaCreate, + build_feedback_repository, + build_run_event_repository, + build_run_repository, + build_thread_meta_repository, +) + + +async def _make_persistence(tmp_path): + persistence = await create_persistence_from_database_config( + SimpleNamespace( + backend="sqlite", + sqlite_dir=str(tmp_path), + echo_sql=False, + pool_size=5, + ) + ) + await persistence.setup() + return persistence + + +@pytest.mark.anyio +async def test_storage_run_repository_filters_and_aggregates(tmp_path): + persistence = await _make_persistence(tmp_path) + old = datetime.now(UTC) - timedelta(hours=1) + newer = datetime.now(UTC) + try: + async with persistence.session_factory() as session: + repo = build_run_repository(session) + await repo.create_run( + RunCreate( + run_id="run-old", + thread_id="thread-1", + user_id="alice", + status="pending", + model_name="model-a", + metadata={"kind": "draft"}, + kwargs={"temperature": 0.2}, + created_time=old, + ) + ) + await repo.create_run( + RunCreate( + run_id="run-new", + thread_id="thread-1", + user_id="bob", + status="running", + model_name="model-b", + error="queued", + created_time=newer, + ) + ) + await repo.create_run(RunCreate(run_id="run-other", thread_id="thread-2", status="running")) + await repo.update_run_completion( + "run-old", + status="success", + total_input_tokens=7, + total_output_tokens=3, + total_tokens=10, + llm_call_count=1, + lead_agent_tokens=8, + subagent_tokens=2, + first_human_message="hello", + last_ai_message="world", + ) + await repo.update_run_completion( + "run-new", + status="error", + total_tokens=5, + middleware_tokens=5, + error="failed", + ) + await session.commit() + + async with persistence.session_factory() as session: + repo = build_run_repository(session) + fetched = await repo.get_run("run-old") + assert fetched is not None + assert fetched.metadata == {"kind": "draft"} + assert fetched.kwargs == {"temperature": 0.2} + assert fetched.first_human_message == "hello" + assert fetched.last_ai_message == "world" + + all_thread_runs = await repo.list_runs_by_thread("thread-1") + assert [run.run_id for run in all_thread_runs] == ["run-new", "run-old"] + alice_runs = await repo.list_runs_by_thread("thread-1", user_id="alice") + assert [run.run_id for run in alice_runs] == ["run-old"] + + pending = await repo.list_pending(before=datetime.now(UTC).isoformat()) + assert [run.run_id for run in pending] == [] + + agg = await repo.aggregate_tokens_by_thread("thread-1") + assert agg["total_tokens"] == 15 + assert agg["total_input_tokens"] == 7 + assert agg["total_output_tokens"] == 3 + assert agg["total_runs"] == 2 + assert agg["by_model"] == { + "model-a": {"tokens": 10, "runs": 1}, + "model-b": {"tokens": 5, "runs": 1}, + } + assert agg["by_caller"] == {"lead_agent": 8, "subagent": 2, "middleware": 5} + finally: + await persistence.aclose() + + +@pytest.mark.anyio +async def test_storage_thread_meta_repository_search_update_delete(tmp_path): + persistence = await _make_persistence(tmp_path) + try: + async with persistence.session_factory() as session: + repo = build_thread_meta_repository(session) + await repo.create_thread_meta( + ThreadMetaCreate( + thread_id="thread-1", + assistant_id="agent-a", + user_id="alice", + display_name="Initial", + status="idle", + metadata={"topic": "finance", "region": "cn"}, + ) + ) + await repo.create_thread_meta( + ThreadMetaCreate( + thread_id="thread-2", + assistant_id="agent-b", + user_id="bob", + status="running", + metadata={"topic": "legal"}, + ) + ) + await repo.update_thread_meta( + "thread-1", + display_name="Updated", + status="running", + metadata={"topic": "finance", "region": "us"}, + ) + await session.commit() + + async with persistence.session_factory() as session: + repo = build_thread_meta_repository(session) + fetched = await repo.get_thread_meta("thread-1") + assert fetched is not None + assert fetched.display_name == "Updated" + assert fetched.status == "running" + assert fetched.metadata == {"topic": "finance", "region": "us"} + + by_metadata = await repo.search_threads(metadata={"topic": "finance"}, user_id="alice") + assert [thread.thread_id for thread in by_metadata] == ["thread-1"] + by_assistant = await repo.search_threads(assistant_id="agent-b") + assert [thread.thread_id for thread in by_assistant] == ["thread-2"] + + await repo.delete_thread("thread-1") + await session.commit() + + async with persistence.session_factory() as session: + repo = build_thread_meta_repository(session) + assert await repo.get_thread_meta("thread-1") is None + finally: + await persistence.aclose() + + +@pytest.mark.anyio +async def test_storage_feedback_repository_lists_and_deletes(tmp_path): + persistence = await _make_persistence(tmp_path) + try: + async with persistence.session_factory() as session: + repo = build_feedback_repository(session) + first = await repo.create_feedback( + FeedbackCreate( + feedback_id="fb-1", + run_id="run-1", + thread_id="thread-1", + rating=1, + user_id="alice", + message_id="msg-1", + comment="good", + ) + ) + second = await repo.create_feedback( + FeedbackCreate( + feedback_id="fb-2", + run_id="run-1", + thread_id="thread-1", + rating=-1, + user_id="bob", + ) + ) + await session.commit() + + async with persistence.session_factory() as session: + repo = build_feedback_repository(session) + assert await repo.get_feedback(first.feedback_id) == first + assert [item.feedback_id for item in await repo.list_feedback_by_run("run-1")] == [ + second.feedback_id, + first.feedback_id, + ] + assert {item.feedback_id for item in await repo.list_feedback_by_thread("thread-1")} == { + "fb-1", + "fb-2", + } + assert await repo.delete_feedback("fb-1") is True + assert await repo.delete_feedback("missing") is False + with pytest.raises(ValueError, match="rating must be"): + await repo.create_feedback( + FeedbackCreate( + feedback_id="fb-bad", + run_id="run-1", + thread_id="thread-1", + rating=0, + ) + ) + await session.commit() + + async with persistence.session_factory() as session: + repo = build_feedback_repository(session) + assert await repo.get_feedback("fb-1") is None + finally: + await persistence.aclose() + + +@pytest.mark.anyio +async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_path): + persistence = await _make_persistence(tmp_path) + try: + async with persistence.session_factory() as session: + repo = build_run_event_repository(session) + rows = await repo.append_batch( + [ + RunEventCreate( + thread_id="thread-1", + run_id="run-1", + user_id="alice", + event_type="message", + category="message", + content={"role": "user", "content": "hello"}, + metadata={"source": "input"}, + ), + RunEventCreate( + thread_id="thread-1", + run_id="run-1", + event_type="tool", + category="debug", + content="tool-call", + ), + RunEventCreate( + thread_id="thread-1", + run_id="run-2", + event_type="message", + category="message", + content="second", + ), + RunEventCreate( + thread_id="thread-2", + run_id="run-3", + event_type="message", + category="message", + content="other-thread", + ), + ] + ) + await session.commit() + + assert [(row.thread_id, row.seq) for row in rows] == [ + ("thread-1", 1), + ("thread-1", 2), + ("thread-1", 3), + ("thread-2", 1), + ] + assert rows[0].content == {"role": "user", "content": "hello"} + assert rows[0].metadata == {"source": "input", "content_is_json": True} + + async with persistence.session_factory() as session: + repo = build_run_event_repository(session) + messages = await repo.list_messages("thread-1", limit=2) + assert [event.seq for event in messages] == [1, 3] + assert await repo.count_messages("thread-1") == 2 + + after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5) + assert [event.seq for event in after] == [1] + before = await repo.list_messages("thread-1", before_seq=3, limit=5) + assert [event.seq for event in before] == [1] + + events = await repo.list_events("thread-1", "run-1", event_types=["tool"]) + assert [event.content for event in events] == ["tool-call"] + + assert await repo.delete_by_run("thread-1", "run-1") == 2 + assert await repo.delete_by_thread("thread-2") == 1 + await session.commit() + + async with persistence.session_factory() as session: + repo = build_run_event_repository(session) + remaining = await repo.list_events("thread-1", "run-2") + assert [event.seq for event in remaining] == [3] + assert await repo.count_messages("thread-2") == 0 + finally: + await persistence.aclose() diff --git a/backend/tests/test_storage_user_repository.py b/backend/tests/test_storage_user_repository.py new file mode 100644 index 000000000..01e535bf3 --- /dev/null +++ b/backend/tests/test_storage_user_repository.py @@ -0,0 +1,178 @@ +from __future__ import annotations + +import asyncio +import os +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from pathlib import Path +from uuid import uuid4 + +import pytest + +os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml")) + +import store.repositories.models # noqa: F401 +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from store.persistence import MappedBase +from store.repositories import UserCreate, UserNotFoundError, build_user_repository + + +@asynccontextmanager +async def _session_factory(tmp_path) -> AsyncGenerator[async_sessionmaker[AsyncSession]]: + db_path = tmp_path / "storage-users.db" + engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}") + async with engine.begin() as conn: + await conn.run_sync(MappedBase.metadata.create_all) + + try: + yield async_sessionmaker(engine, expire_on_commit=False) + finally: + await engine.dispose() + + +async def _create_user( + session_factory: async_sessionmaker[AsyncSession], + *, + email: str = "user@example.com", + system_role: str = "user", + oauth_provider: str | None = None, + oauth_id: str | None = None, +): + async with session_factory() as session: + repo = build_user_repository(session) + user = await repo.create_user( + UserCreate( + id=str(uuid4()), + email=email, + password_hash="hash", + system_role=system_role, # type: ignore[arg-type] + oauth_provider=oauth_provider, + oauth_id=oauth_id, + ) + ) + await session.commit() + return user + + +def test_create_and_get_user_by_id_and_email(tmp_path): + async def run() -> None: + async with _session_factory(tmp_path) as session_factory: + created = await _create_user(session_factory) + + async with session_factory() as session: + repo = build_user_repository(session) + + by_id = await repo.get_user_by_id(created.id) + by_email = await repo.get_user_by_email(created.email) + + assert by_id == created + assert by_email == created + assert created.system_role == "user" + assert created.needs_setup is False + assert created.token_version == 0 + + asyncio.run(run()) + + +def test_duplicate_email_raises_value_error(tmp_path): + async def run() -> None: + async with _session_factory(tmp_path) as session_factory: + await _create_user(session_factory, email="dupe@example.com") + + async with session_factory() as session: + repo = build_user_repository(session) + with pytest.raises(ValueError, match="Email already registered"): + await repo.create_user( + UserCreate( + id=str(uuid4()), + email="dupe@example.com", + password_hash="hash", + ) + ) + + asyncio.run(run()) + + +def test_oauth_lookup_and_plain_users_without_oauth(tmp_path): + async def run() -> None: + async with _session_factory(tmp_path) as session_factory: + await _create_user(session_factory, email="local-1@example.com") + await _create_user(session_factory, email="local-2@example.com") + oauth_user = await _create_user( + session_factory, + email="oauth@example.com", + oauth_provider="github", + oauth_id="gh-123", + ) + + async with session_factory() as session: + repo = build_user_repository(session) + + assert await repo.count_users() == 3 + assert await repo.get_user_by_oauth("github", "gh-123") == oauth_user + assert await repo.get_user_by_oauth("github", "missing") is None + + asyncio.run(run()) + + +def test_count_admins_and_get_first_admin(tmp_path): + async def run() -> None: + async with _session_factory(tmp_path) as session_factory: + await _create_user(session_factory, email="user@example.com") + admin = await _create_user( + session_factory, + email="admin@example.com", + system_role="admin", + ) + + async with session_factory() as session: + repo = build_user_repository(session) + + assert await repo.count_users() == 2 + assert await repo.count_admin_users() == 1 + assert await repo.get_first_admin() == admin + + asyncio.run(run()) + + +def test_update_user_round_trips_token_version_and_setup_state(tmp_path): + async def run() -> None: + async with _session_factory(tmp_path) as session_factory: + created = await _create_user(session_factory) + updated = created.model_copy( + update={ + "email": "renamed@example.com", + "token_version": 4, + "needs_setup": True, + } + ) + + async with session_factory() as session: + repo = build_user_repository(session) + saved = await repo.update_user(updated) + await session.commit() + + async with session_factory() as session: + repo = build_user_repository(session) + fetched = await repo.get_user_by_id(created.id) + + assert saved.email == "renamed@example.com" + assert fetched == updated + + asyncio.run(run()) + + +def test_update_missing_user_raises(tmp_path): + async def run() -> None: + async with _session_factory(tmp_path) as session_factory: + missing = UserCreate(id=str(uuid4()), email="missing@example.com") + + async with session_factory() as session: + repo = build_user_repository(session) + created_shape = await repo.create_user(missing) + await session.rollback() + + with pytest.raises(UserNotFoundError): + await repo.update_user(created_shape) + + asyncio.run(run()) diff --git a/backend/uv.lock b/backend/uv.lock index e144fb07e..53fcae74e 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -14,6 +14,7 @@ resolution-markers = [ members = [ "deer-flow", "deerflow-harness", + "deerflow-storage", ] [[package]] @@ -136,6 +137,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441, upload-time = "2026-03-31T22:00:12.791Z" }, ] +[[package]] +name = "aiomysql" +version = "0.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pymysql" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/29/e0/302aeffe8d90853556f47f3106b89c16cc2ec2a4d269bdfd82e3f4ae12cc/aiomysql-0.3.2.tar.gz", hash = "sha256:72d15ef5cfc34c03468eb41e1b90adb9fd9347b0b589114bd23ead569a02ac1a", size = 108311, upload-time = "2025-10-22T00:15:21.278Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4c/af/aae0153c3e28712adaf462328f6c7a3c196a1c1c27b491de4377dd3e6b52/aiomysql-0.3.2-py3-none-any.whl", hash = "sha256:c82c5ba04137d7afd5c693a258bea8ead2aad77101668044143a991e04632eb2", size = 71834, upload-time = "2025-10-22T00:15:15.905Z" }, +] + [[package]] name = "aiosignal" version = "1.4.0" @@ -746,6 +759,7 @@ source = { virtual = "." } dependencies = [ { name = "bcrypt" }, { name = "deerflow-harness" }, + { name = "deerflow-storage" }, { name = "dingtalk-stream" }, { name = "email-validator" }, { name = "fastapi" }, @@ -765,6 +779,7 @@ dependencies = [ [package.optional-dependencies] postgres = [ { name = "deerflow-harness", extra = ["postgres"] }, + { name = "deerflow-storage", extra = ["postgres"] }, ] [package.dev-dependencies] @@ -780,6 +795,8 @@ requires-dist = [ { name = "bcrypt", specifier = ">=4.0.0" }, { name = "deerflow-harness", editable = "packages/harness" }, { name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" }, + { name = "deerflow-storage", editable = "packages/storage" }, + { name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" }, { name = "dingtalk-stream", specifier = ">=0.24.3" }, { name = "email-validator", specifier = ">=2.0.0" }, { name = "fastapi", specifier = ">=0.115.0" }, @@ -901,6 +918,54 @@ requires-dist = [ ] provides-extras = ["ollama", "postgres", "pymupdf"] +[[package]] +name = "deerflow-storage" +version = "0.1.0" +source = { editable = "packages/storage" } +dependencies = [ + { name = "alembic" }, + { name = "dotenv" }, + { name = "langgraph" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "sqlalchemy", extra = ["asyncio"] }, +] + +[package.optional-dependencies] +mysql = [ + { name = "aiomysql" }, + { name = "langgraph-checkpoint-mysql" }, +] +postgres = [ + { name = "asyncpg" }, + { name = "langgraph-checkpoint-postgres" }, + { name = "psycopg", extra = ["binary"] }, + { name = "psycopg-pool" }, +] +sqlite = [ + { name = "aiosqlite" }, + { name = "langgraph-checkpoint-sqlite" }, +] + +[package.metadata] +requires-dist = [ + { name = "aiomysql", marker = "extra == 'mysql'", specifier = ">=0.2" }, + { name = "aiosqlite", marker = "extra == 'sqlite'", specifier = ">=0.22.1" }, + { name = "alembic", specifier = ">=1.13" }, + { name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" }, + { name = "dotenv", specifier = ">=0.9.9" }, + { name = "langgraph", specifier = ">=1.1.9" }, + { name = "langgraph-checkpoint-mysql", marker = "extra == 'mysql'", specifier = ">=3.0.0" }, + { name = "langgraph-checkpoint-postgres", marker = "extra == 'postgres'", specifier = ">=3.0.5" }, + { name = "langgraph-checkpoint-sqlite", marker = "extra == 'sqlite'", specifier = ">=3.0.3" }, + { name = "psycopg", extras = ["binary"], marker = "extra == 'postgres'", specifier = ">=3.3.3" }, + { name = "psycopg-pool", marker = "extra == 'postgres'", specifier = ">=3.3.0" }, + { name = "pydantic", specifier = ">=2.12.5" }, + { name = "pyyaml", specifier = ">=6.0.3" }, + { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0,<3.0" }, +] +provides-extras = ["postgres", "mysql", "sqlite"] + [[package]] name = "defusedxml" version = "0.7.1" @@ -1914,6 +1979,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b9/5a/6dba29dd89b0a46ae21c707da0f9d17e94f27d3e481ed15bc99d6bd20aa6/langgraph_checkpoint-4.0.2-py3-none-any.whl", hash = "sha256:59b0f29216128a629c58dd07c98aa004f82f51805d5573126ffb419b753ff253", size = 51000, upload-time = "2026-04-15T21:02:59.096Z" }, ] +[[package]] +name = "langgraph-checkpoint-mysql" +version = "3.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "langgraph-checkpoint" }, + { name = "orjson" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/e0/4e/0a6c78e5d3f2ca1525903c2363e721873594b6b77dd83537a6369193c474/langgraph_checkpoint_mysql-3.0.0.tar.gz", hash = "sha256:006aaa089f4c2fbd7b2c113b800ccd3dbb95f92203e656451677256b4b4f880f", size = 213142, upload-time = "2026-01-23T11:11:15.74Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/08/68/343103a7fae05523f9cecabbec2babdb737e66b4bf6ea48ae00c685ed11c/langgraph_checkpoint_mysql-3.0.0-py3-none-any.whl", hash = "sha256:7560ccd16e7596a047e15a307cec12dbd88fdcaab45a75759e5c6adef22a27d1", size = 38009, upload-time = "2026-01-23T11:11:14.697Z" }, +] + [[package]] name = "langgraph-checkpoint-postgres" version = "3.0.5" @@ -3442,6 +3521,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/e6/38/84bf29f4dd72e6c450546df6ca8f53021f764fd945ba67dcc235d39bc20e/pymupdf4llm-1.27.2.3-py3-none-any.whl", hash = "sha256:bd724b79fa3f06a5b28d7a65f7acfa8de56e04bdb603ac2d6dff315e0d151aaa", size = 77348, upload-time = "2026-04-24T14:11:04.305Z" }, ] +[[package]] +name = "pymysql" +version = "1.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7f/ec/8d45c920e90445f0b75c590b32851853ed319763b0d8dff8d283052da8cf/pymysql-1.1.3.tar.gz", hash = "sha256:e70ebf2047a4edf6138cf79c68ad418ef620af65900aa585c5e8bfc95044d43a", size = 48207, upload-time = "2026-05-01T09:09:54.532Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/8e/dc/9085f3d6f497e9b25fb40d6e8ecef3ddbb5cf977a949b933624a299f5c16/pymysql-1.1.3-py3-none-any.whl", hash = "sha256:8164ba62c552f6105f3b11753352d0f16b90d1703ba67d81923d5a8a5d1c5289", size = 45356, upload-time = "2026-05-01T09:09:53.316Z" }, +] + [[package]] name = "pypdfium2" version = "5.7.1"