feat(storage): add storage package base
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
[project]
|
||||
name = "deerflow-storage"
|
||||
version = "0.1.0"
|
||||
description = "DeerFlow storage framework"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"dotenv>=0.9.9",
|
||||
"pydantic>=2.12.5",
|
||||
"pyyaml>=6.0.3",
|
||||
"sqlalchemy[asyncio]>=2.0,<3.0",
|
||||
"alembic>=1.13",
|
||||
"langgraph>=1.1.9",
|
||||
]
|
||||
[project.optional-dependencies]
|
||||
postgres = [
|
||||
"asyncpg>=0.29",
|
||||
"langgraph-checkpoint-postgres>=3.0.5",
|
||||
"psycopg[binary]>=3.3.3",
|
||||
"psycopg-pool>=3.3.0",
|
||||
]
|
||||
mysql = [
|
||||
"aiomysql>=0.2",
|
||||
"langgraph-checkpoint-mysql>=3.0.0",
|
||||
]
|
||||
sqlite = [
|
||||
"aiosqlite>=0.22.1",
|
||||
"langgraph-checkpoint-sqlite>=3.0.3"
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["store"]
|
||||
@@ -0,0 +1,5 @@
|
||||
from .enums import DataBaseType
|
||||
|
||||
__all__ = [
|
||||
'DataBaseType',
|
||||
]
|
||||
@@ -0,0 +1,41 @@
|
||||
from enum import Enum
|
||||
from enum import IntEnum as SourceIntEnum
|
||||
from enum import StrEnum as SourceStrEnum
|
||||
from typing import Any, TypeVar
|
||||
|
||||
T = TypeVar('T', bound=Enum)
|
||||
|
||||
|
||||
class _EnumBase:
    """Mixin adding member-introspection helpers to enum classes.

    Every helper reads ``cls.__members__``, so the mixin is only useful when
    combined with an ``enum.Enum`` subclass.
    """

    @classmethod
    def get_member_keys(cls) -> list[str]:
        """Return the member names, in ``__members__`` order."""
        return [name for name in cls.__members__]

    @classmethod
    def get_member_values(cls) -> list:
        """Return each member's ``value``, in ``__members__`` order."""
        return [member.value for _, member in cls.__members__.items()]

    @classmethod
    def get_member_dict(cls) -> dict[str, Any]:
        """Return a ``{name: value}`` mapping covering every member."""
        mapping: dict[str, Any] = {}
        for name, member in cls.__members__.items():
            mapping[name] = member.value
        return mapping
|
||||
|
||||
|
||||
class IntEnum(_EnumBase, SourceIntEnum):
    """Standard-library ``IntEnum`` extended with the ``_EnumBase`` helpers."""
|
||||
|
||||
|
||||
class StrEnum(_EnumBase, SourceStrEnum):
    """Standard-library ``StrEnum`` extended with the ``_EnumBase`` helpers."""
|
||||
|
||||
|
||||
class DataBaseType(StrEnum):
    """Database backends known to the storage package.

    Each member's value equals its name, so members compare equal to the
    plain strings ``'sqlite'``, ``'mysql'`` and ``'postgresql'``.
    """

    sqlite = 'sqlite'
    mysql = 'mysql'
    postgresql = 'postgresql'
|
||||
@@ -0,0 +1,290 @@
|
||||
import logging
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Any, Self
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from store.config.storage_config import StorageConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_config_candidates() -> tuple[Path, ...]:
    """Return the default ``config.yaml`` locations, in lookup order.

    Order:
        1. The current working directory.
        2. The backend directory, i.e. four levels above the directory that
           contains this module.
        3. The parent of the backend directory (the repository root).

    Duplicate paths are dropped, keeping the first occurrence.

    When this module is installed fewer than five levels below the filesystem
    root, the shallowest available ancestor stands in for the backend
    directory instead of raising ``IndexError``.
    """
    ancestors = Path(__file__).resolve().parents
    # parents[4] is the backend directory in a source checkout; clamp the index
    # so a shallow install path (e.g. /app/store/config/) cannot fail.
    backend_dir = ancestors[min(4, len(ancestors) - 1)]
    repo_root = backend_dir.parent
    cwd = Path.cwd().resolve()
    candidates = (
        cwd / "config.yaml",
        backend_dir / "config.yaml",
        repo_root / "config.yaml",
    )
    # dict.fromkeys de-duplicates while preserving insertion order.
    return tuple(dict.fromkeys(candidates))
|
||||
|
||||
|
||||
def _storage_from_database_config(config_data: dict[str, Any]) -> None:
    """Derive a ``storage`` section from a legacy ``database`` section, in place.

    Nothing happens when ``config_data`` already has a ``storage`` key or when
    its ``database`` value is not a dict. Otherwise a ``storage`` dict is built
    from the ``database`` keys and stored under ``config_data["storage"]``.

    Raises:
        ValueError: If ``database.backend`` is ``"memory"``.
    """
    legacy = config_data.get("database")
    if "storage" in config_data or not isinstance(legacy, dict):
        return

    backend = legacy.get("backend")
    if backend == "memory":
        raise ValueError("database.backend='memory' is not supported by storage; handle memory mode before loading storage config")

    derived: dict[str, Any] = {"driver": backend}
    for key, fallback in (
        ("sqlite_dir", ".deer-flow/data"),
        ("echo_sql", False),
        ("pool_size", 5),
    ):
        derived[key] = legacy.get(key, fallback)

    url_text = legacy.get("postgres_url")
    if backend == "postgres" and isinstance(url_text, str) and url_text:
        # Imported lazily so SQLAlchemy is only needed on the postgres path.
        from sqlalchemy.engine.url import make_url

        url = make_url(url_text)
        derived["database_url"] = url_text
        derived["username"] = url.username or ""
        derived["password"] = url.password or ""
        derived["host"] = url.host or "localhost"
        derived["port"] = url.port or 5432
        derived["db_name"] = url.database or "deerflow"

    config_data["storage"] = derived
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
    """DeerFlow application configuration."""

    timezone: str = Field(default="UTC", description="Timezone for scheduling and timestamps (e.g. 'UTC', 'America/New_York')")
    log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
    # default_factory defers construction to instantiation time instead of
    # building one StorageConfig while the class body is executed.
    storage: StorageConfig = Field(default_factory=StorageConfig)
    model_config = ConfigDict(extra="allow", frozen=False)

    @classmethod
    def resolve_config_path(cls, config_path: str | None = None) -> Path:
        """Resolve the config file path.

        Priority:
        1. If provided `config_path` argument, use it.
        2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
        3. Otherwise, return the first existing path from `_default_config_candidates()`.

        Raises:
            FileNotFoundError: If the selected path, or every default
                candidate, does not exist.
        """
        if config_path:
            path = Path(config_path)
            if not path.exists():
                raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
            return path

        # Read the variable once so the check and the value cannot diverge.
        if env_path := os.getenv("DEER_FLOW_CONFIG_PATH"):
            path = Path(env_path)
            if not path.exists():
                raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
            return path

        for path in _default_config_candidates():
            if path.exists():
                return path
        raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")

    @classmethod
    def from_file(cls, config_path: str | None = None) -> Self:
        """Load and validate config from YAML. See `resolve_config_path` for path resolution.

        After parsing, the config version is checked, `$VAR` strings are
        resolved, a legacy `database` section is mapped to `storage`, and a
        non-empty `TIMEZONE` environment variable overrides `timezone`.

        Raises:
            FileNotFoundError: If no config file can be resolved.
            ValueError: If the YAML root is not a mapping, or a referenced
                environment variable is not set.
        """
        resolved_path = cls.resolve_config_path(config_path)
        with open(resolved_path, encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}

        # A list or scalar at the top level would otherwise fail later with an
        # opaque AttributeError on `.get`.
        if not isinstance(config_data, dict):
            raise ValueError(f"Config file {resolved_path} must contain a YAML mapping at the top level, got {type(config_data).__name__}")

        cls._check_config_version(config_data, resolved_path)

        config_data = cls.resolve_env_variables(config_data)
        _storage_from_database_config(config_data)

        if env_timezone := os.getenv("TIMEZONE"):
            config_data["timezone"] = env_timezone

        return cls.model_validate(config_data)

    @classmethod
    def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
        """Check if the user's config.yaml is outdated compared to config.example.yaml.

        Emits a warning if the user's config_version is lower than the example's.
        Missing or non-integer config_version is treated as version 0.
        `config.example.yaml` is looked up in the config file's directory and
        up to four of its ancestors; when it is absent or unreadable the check
        is skipped silently.
        """
        try:
            user_version = int(config_data.get("config_version", 0))
        except (TypeError, ValueError):
            user_version = 0

        # Find config.example.yaml by searching config.yaml's directory and its parents
        example_path = None
        search_dir = config_path.parent
        for _ in range(5):  # search up to 5 levels
            candidate = search_dir / "config.example.yaml"
            if candidate.exists():
                example_path = candidate
                break
            parent = search_dir.parent
            if parent == search_dir:  # reached the filesystem root
                break
            search_dir = parent
        if example_path is None:
            return

        try:
            with open(example_path, encoding="utf-8") as f:
                example_data = yaml.safe_load(f)
            raw = example_data.get("config_version", 0) if example_data else 0
            try:
                example_version = int(raw)
            except (TypeError, ValueError):
                example_version = 0
        except Exception:
            # Best effort: a broken example file must never block config loading.
            return

        if user_version < example_version:
            logger.warning(
                "Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to "
                "merge new fields into your config.",
                user_version,
                example_version,
            )

    @classmethod
    def resolve_env_variables(cls, config: Any) -> Any:
        """Recursively replace $VAR strings with their environment variable values (e.g. $OPENAI_API_KEY).

        Strings that start with `$` are replaced; dicts and lists are rebuilt
        with their values resolved; every other value is returned unchanged.

        Raises:
            ValueError: If a referenced environment variable is not set.
        """
        if isinstance(config, str):
            if config.startswith("$"):
                env_value = os.getenv(config[1:])
                if env_value is None:
                    raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
                return env_value
            return config
        if isinstance(config, dict):
            return {k: cls.resolve_env_variables(v) for k, v in config.items()}
        if isinstance(config, list):
            return [cls.resolve_env_variables(item) for item in config]
        return config
|
||||
|
||||
|
||||
|
||||
# Process-wide cache of the loaded config; None until a config is loaded or injected.
_app_config: AppConfig | None = None
# Path and modification time recorded when the cached config was loaded from disk.
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
# True while the cached config was injected through set_app_config().
_app_config_is_custom = False
# Context-scoped override returned by get_app_config() ahead of the cache.
_current_app_config: ContextVar[AppConfig | None] = ContextVar(
    "deerflow_current_app_config",
    default=None,
)
# Overrides that were active before each push, restored again on pop.
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar(
    "deerflow_current_app_config_stack",
    default=(),
)
|
||||
|
||||
|
||||
def _get_config_mtime(config_path: Path) -> float | None:
|
||||
"""Get the modification time of a config file if it exists."""
|
||||
try:
|
||||
return config_path.stat().st_mtime
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
    """Load the config from disk and record it, with its path and mtime, in the module cache.

    The custom-config flag is cleared because the cache now reflects a file.
    """
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom

    source_path = AppConfig.resolve_config_path(config_path)
    _app_config = AppConfig.from_file(str(source_path))
    _app_config_path, _app_config_mtime = source_path, _get_config_mtime(source_path)
    _app_config_is_custom = False
    return _app_config
|
||||
|
||||
|
||||
def get_app_config() -> AppConfig:
    """Return the active DeerFlow config.

    Lookup order:
        1. The context-scoped override, when one is set.
        2. The cached config, when it was injected via ``set_app_config()``.
        3. The cached file-backed config, reloaded first if nothing is cached
           or the resolved path or the file's modification time has changed.

    Use ``reload_app_config()`` to force a reload, or ``reset_app_config()``
    to clear the cache.
    """
    override = _current_app_config.get()
    if override is not None:
        return override

    if _app_config_is_custom and _app_config is not None:
        return _app_config

    config_file = AppConfig.resolve_config_path()
    mtime_now = _get_config_mtime(config_file)

    same_path = _app_config_path == config_file
    same_mtime = _app_config_mtime == mtime_now
    if _app_config is not None and same_path and same_mtime:
        return _app_config

    # Only log when the same file was observed with two known, different mtimes.
    if same_path and not same_mtime and _app_config_mtime is not None and mtime_now is not None:
        logger.info(
            "Config file has been modified (mtime: %s -> %s), reloading AppConfig",
            _app_config_mtime,
            mtime_now,
        )
    _load_and_cache_app_config(str(config_file))
    return _app_config
|
||||
|
||||
|
||||
def reload_app_config(config_path: str | None = None) -> AppConfig:
    """Reload the config from file unconditionally, refresh the cache and return it."""
    refreshed = _load_and_cache_app_config(config_path)
    return refreshed
|
||||
|
||||
|
||||
def reset_app_config() -> None:
    """Drop the cached config and its metadata so the next ``get_app_config()`` reloads from file."""
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
    _app_config = _app_config_path = _app_config_mtime = None
    _app_config_is_custom = False
|
||||
|
||||
|
||||
def set_app_config(config: AppConfig) -> None:
    """Install ``config`` as the cached instance without reading any file (for testing).

    The cached path and mtime are cleared and the cache is flagged as custom.
    """
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
    _app_config_path = _app_config_mtime = None
    _app_config, _app_config_is_custom = config, True
|
||||
|
||||
|
||||
def peek_current_app_config() -> AppConfig | None:
    """Return the context-scoped AppConfig override, or ``None`` when none is set."""
    active_override = _current_app_config.get()
    return active_override
|
||||
|
||||
|
||||
def push_current_app_config(config: AppConfig) -> None:
    """Make ``config`` the context-scoped override, remembering the previous one on the stack."""
    saved = _current_app_config_stack.get()
    previous = _current_app_config.get()
    _current_app_config_stack.set((*saved, previous))
    _current_app_config.set(config)
|
||||
|
||||
|
||||
def pop_current_app_config() -> None:
    """Restore the override that was active before the latest push.

    With an empty stack the override is simply cleared.
    """
    saved = _current_app_config_stack.get()
    if saved:
        *remaining, previous = saved
        _current_app_config_stack.set(tuple(remaining))
        _current_app_config.set(previous)
    else:
        _current_app_config.set(None)
|
||||
@@ -0,0 +1,72 @@
|
||||
"""Unified storage backend configuration for checkpointer and application data.

SQLite: checkpointer and application data share one file,
    {sqlite_dir}/deerflow.db (see StorageConfig.sqlite_storage_path).
Postgres / MySQL: shared URL; the checkpointer and the application engine
    each open their own connections.

Sensitive values use $VAR syntax resolved by AppConfig.resolve_env_variables()
before this config is instantiated.
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def _strip_legacy_state_prefix(path: str) -> str:
    """Remove a leading ``.deer-flow/`` segment from ``path``.

    The bare value ``".deer-flow"`` maps to ``"."``; any path without the
    prefix is returned unchanged.
    """
    if path == ".deer-flow":
        return "."
    return path.removeprefix(".deer-flow/")
|
||||
|
||||
|
||||
class StorageConfig(BaseModel):
    """Connection settings for the storage backend (SQLite, MySQL or PostgreSQL)."""

    driver: Literal["mysql", "sqlite", "postgres", "postgresql"] = Field(
        default="sqlite",
        # Fixed: the second literal lacked a trailing space, so the rendered
        # text read "(default),'postgres'".
        description="Storage driver for both checkpointer and application data. "
        "'sqlite' for single-node deployment (default), "
        "'postgres' for production multi-node deployment, "
        "'mysql' for MySQL databases.",
    )
    sqlite_dir: str = Field(
        default=".deer-flow/data",
        description="Directory for SQLite .db files (sqlite driver only).",
    )
    username: str = Field(default="", description="db username ")
    password: str = Field(default="", description="db password. Use $VAR syntax in config.yaml to read from .env.")
    host: str = Field(default="localhost", description="db host.")
    port: int = Field(default=5432, description="db port.")
    db_name: str = Field(default="deerflow", description="db database name.")
    database_url: str = Field(default="", description="Complete SQLAlchemy database URL. Takes precedence for non-SQLite drivers.")
    # NOTE(review): same default and description as sqlite_dir, and not read by
    # the helpers below — confirm whether this field is still needed.
    sqlite_db_path: str = Field(default=".deer-flow/data", description="Directory for SQLite .db files (sqlite driver only).")
    echo_sql: bool = Field(default=False, description="Log all SQL statements (debug only).")
    pool_size: int = Field(default=5, description="Connection pool size per layer.")

    # -- Derived helpers (not user-configured) --

    @property
    def _resolved_sqlite_dir(self) -> str:
        """Return ``sqlite_dir`` as an absolute path string.

        Absolute values are resolved as-is. Relative values go through
        ``deerflow.config.paths.resolve_path`` (after dropping a legacy
        ``.deer-flow/`` prefix) when that module can be imported, and are
        otherwise resolved against the current working directory.
        """
        from pathlib import Path

        path = Path(self.sqlite_dir)
        if path.is_absolute():
            return str(path.resolve())

        try:
            from deerflow.config.paths import resolve_path
        except ImportError:
            # The optional deerflow package is unavailable.
            return str(path.resolve())
        return str(resolve_path(_strip_legacy_state_prefix(self.sqlite_dir)))

    @property
    def sqlite_storage_path(self) -> str:
        """SQLite file path for storage-owned app data and checkpointer."""
        return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
|
||||
@@ -0,0 +1,32 @@
|
||||
from store.persistence.base_model import (
|
||||
Base,
|
||||
DataClassBase,
|
||||
DateTimeMixin,
|
||||
MappedBase,
|
||||
TimeZone,
|
||||
UniversalText,
|
||||
id_key,
|
||||
)
|
||||
|
||||
from .factory import (
|
||||
create_persistence,
|
||||
create_persistence_from_database_config,
|
||||
create_persistence_from_storage_config,
|
||||
storage_config_from_database_config,
|
||||
)
|
||||
from .types import AppPersistence
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"DataClassBase",
|
||||
"DateTimeMixin",
|
||||
"MappedBase",
|
||||
"TimeZone",
|
||||
"UniversalText",
|
||||
"id_key",
|
||||
"create_persistence",
|
||||
"create_persistence_from_database_config",
|
||||
"create_persistence_from_storage_config",
|
||||
"storage_config_from_database_config",
|
||||
"AppPersistence",
|
||||
]
|
||||
@@ -0,0 +1,107 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from sqlalchemy import BigInteger, DateTime, Integer, Text, TypeDecorator
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
|
||||
|
||||
from store.common import DataBaseType
|
||||
from store.config.app_config import get_app_config
|
||||
from store.utils import get_timezone
|
||||
|
||||
# Both are resolved once, when this module is imported.
timezone = get_timezone()
app_config = get_app_config()

# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT)
_id_type = BigInteger if app_config.storage.driver != DataBaseType.sqlite else Integer

# Annotated alias for an auto-incrementing integer primary-key column.
id_key = Annotated[
    int,
    mapped_column(
        _id_type,
        primary_key=True,
        autoincrement=True,
        unique=True,
        index=True,
        sort_order=-999,
        comment="Primary key ID",
    ),
]
|
||||
|
||||
|
||||
class UniversalText(TypeDecorator[str]):
    """Long text column type: LONGTEXT when the configured driver is MySQL, Text otherwise.

    The underlying type is fixed when this class is defined, from the
    ``app_config`` loaded at import time. Values pass through unchanged in
    both directions.
    """

    impl = Text if app_config.storage.driver != DataBaseType.mysql else LONGTEXT
    cache_ok = True

    def process_bind_param(self, value: str | None, dialect) -> str | None:  # noqa: ANN001
        """Return the value unchanged on its way to the database."""
        return value

    def process_result_value(self, value: str | None, dialect) -> str | None:  # noqa: ANN001
        """Return the value unchanged on its way out of the database."""
        return value
|
||||
|
||||
|
||||
class TimeZone(TypeDecorator[datetime]):
    """Timezone-aware datetime type compatible with PostgreSQL and MySQL."""

    impl = DateTime(timezone=True)
    cache_ok = True

    @property
    def python_type(self) -> type[datetime]:
        """Python type handled by this column type."""
        return datetime

    def process_bind_param(self, value: datetime | None, dialect) -> datetime | None:  # noqa: ANN001
        """Pass ``value`` through ``timezone.from_datetime`` when its UTC offset differs from ``timezone.now()``.

        ``None`` and values that already carry the same offset are returned
        unchanged.
        """
        if value is None:
            return None
        if value.utcoffset() != timezone.now().utcoffset():
            return timezone.from_datetime(value)
        return value

    def process_result_value(self, value: datetime | None, dialect) -> datetime | None:  # noqa: ANN001
        """Attach ``timezone.tz_info`` to naive datetimes read from the database."""
        if value is None or value.tzinfo is not None:
            return value
        return value.replace(tzinfo=timezone.tz_info)
|
||||
|
||||
|
||||
class DateTimeMixin(MappedAsDataclass):
    """Dataclass mixin contributing ``created_time`` and ``updated_time`` columns.

    Both columns use the ``TimeZone`` type, are left out of the generated
    ``__init__`` and carry ``sort_order=999``.
    """

    # Defaults to ``timezone.now`` through the dataclass default factory.
    created_time: Mapped[datetime] = mapped_column(
        TimeZone,
        comment="Created at",
        sort_order=999,
        default_factory=timezone.now,
        init=False,
    )
    # Has no default; ``timezone.now`` is supplied through ``onupdate``.
    updated_time: Mapped[datetime | None] = mapped_column(
        TimeZone,
        comment="Updated at",
        sort_order=999,
        onupdate=timezone.now,
        init=False,
    )
|
||||
|
||||
|
||||
class MappedBase(AsyncAttrs, DeclarativeBase):
    """Async-capable declarative base for all ORM models."""

    @declared_attr.directive
    def __tablename__(cls) -> str:
        """Table name: the lower-cased class name."""
        table_name = cls.__name__.lower()
        return table_name

    @declared_attr.directive
    def __table_args__(cls) -> dict:
        """Table arguments: the class docstring (or an empty string) as the table comment."""
        table_comment = cls.__doc__ or ""
        return {"comment": table_comment}
|
||||
|
||||
|
||||
class DataClassBase(MappedAsDataclass, MappedBase):
    """Abstract declarative base whose subclasses are mapped as dataclasses."""

    # Abstract: this class itself is not mapped to a table.
    __abstract__ = True
|
||||
|
||||
|
||||
class Base(DataClassBase, DateTimeMixin):
    """Abstract dataclass base that also carries the created_time / updated_time columns."""

    # Abstract: this class itself is not mapped to a table.
    __abstract__ = True
|
||||
@@ -0,0 +1,9 @@
|
||||
from .mysql import build_mysql_persistence
|
||||
from .postgres import build_postgres_persistence
|
||||
from .sqlite import build_sqlite_persistence
|
||||
|
||||
__all__ = [
|
||||
"build_postgres_persistence",
|
||||
"build_mysql_persistence",
|
||||
"build_sqlite_persistence",
|
||||
]
|
||||
@@ -0,0 +1,75 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def _validate_mysql_driver(db_url: str) -> str:
    """Return the driver name of ``db_url`` after checking that it is an async MySQL driver.

    Raises:
        ValueError: If the driver is neither ``aiomysql`` nor ``asyncmy``.
    """
    driver = make_url(db_url).get_driver_name()
    if driver in {"aiomysql", "asyncmy"}:
        return driver
    raise ValueError(
        "MySQL persistence requires async SQLAlchemy driver "
        f"(aiomysql/asyncmy), got: {driver!r}"
    )
|
||||
|
||||
|
||||
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
    """Build the MySQL persistence bundle: async engine, session factory and LangGraph saver.

    Args:
        db_url: SQLAlchemy URL using an async driver (``aiomysql`` or ``asyncmy``).
        echo: Forwarded to ``create_async_engine`` as ``echo``.
        pool_size: Forwarded to ``create_async_engine`` as ``pool_size``.

    Returns:
        An ``AppPersistence`` whose ``setup`` creates the checkpoint tables and
        the ORM tables, and whose ``aclose`` disposes the engine and then exits
        the saver context.

    Raises:
        ValueError: If ``db_url`` does not use a supported async driver.
    """
    _validate_mysql_driver(db_url)

    from langgraph.checkpoint.mysql.aio import AIOMySQLSaver

    # Imported for its side effects only; presumably registers the ORM models
    # on MappedBase.metadata before create_all runs — confirm.
    import store.repositories.models  # noqa: F401

    engine = create_async_engine(
        db_url,
        echo=echo,
        future=True,
        pool_pre_ping=True,
        pool_size=pool_size,
        json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
    )

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    # The saver takes a plain connection string, not a SQLAlchemy URL object.
    # Render it without the SQLAlchemy driver suffix and with the real password
    # (str(URL) would mask it as ***).
    saver_url = make_url(db_url).set(drivername="mysql")
    saver_cm = AIOMySQLSaver.from_conn_string(saver_url.render_as_string(hide_password=False))
    # Entered manually so the saver stays open until aclose() is called.
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables / migrations
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
    """Build the PostgreSQL persistence bundle: async engine, session factory and LangGraph saver.

    Args:
        db_url: SQLAlchemy URL for the async engine (e.g. ``postgresql+asyncpg``).
        echo: Forwarded to ``create_async_engine`` as ``echo``.
        pool_size: Forwarded to ``create_async_engine`` as ``pool_size``.

    Returns:
        An ``AppPersistence`` whose ``setup`` creates the checkpoint tables and
        the ORM tables, and whose ``aclose`` disposes the engine and then exits
        the saver context.
    """
    from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
    from sqlalchemy.engine import make_url

    # Imported for its side effects only; presumably registers the ORM models
    # on MappedBase.metadata before create_all runs — confirm.
    import store.repositories.models  # noqa: F401

    engine = create_async_engine(
        db_url,
        echo=echo,
        future=True,
        pool_pre_ping=True,
        pool_size=pool_size,
        json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
    )

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    # The saver takes a plain connection string, not a SQLAlchemy URL object,
    # and the "+asyncpg" driver suffix is SQLAlchemy-specific. Render a
    # postgresql:// URI with the real password (str(URL) would mask it as ***).
    saver_url = make_url(db_url).set(drivername="postgresql")
    saver_cm = AsyncPostgresSaver.from_conn_string(saver_url.render_as_string(hide_password=False))
    # Entered manually so the saver stays open until aclose() is called.
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables / migrations
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
|
||||
@@ -0,0 +1,69 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
async def build_sqlite_persistence(db_url: URL, *, echo: bool = False) -> AppPersistence:
    """Build the SQLite persistence bundle: async engine, session factory and LangGraph saver.

    Every engine connection gets WAL journaling, ``synchronous=NORMAL`` and
    foreign-key enforcement. The saver opens the database file named by
    ``db_url.database``.

    Returns:
        An ``AppPersistence`` whose ``setup`` creates the checkpoint tables and
        the ORM tables, and whose ``aclose`` disposes the engine and then exits
        the saver context.
    """
    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    import store.repositories.models  # noqa: F401

    def _dump_json(obj):  # noqa: ANN001, ANN202
        return json.dumps(obj, ensure_ascii=False)

    engine = create_async_engine(db_url, echo=echo, future=True, json_serializer=_dump_json)

    def _enable_sqlite_pragmas(dbapi_conn, _record):  # noqa: ANN001
        cursor = dbapi_conn.cursor()
        try:
            for statement in (
                "PRAGMA journal_mode=WAL;",
                "PRAGMA synchronous=NORMAL;",
                "PRAGMA foreign_keys=ON;",
            ):
                cursor.execute(statement)
        finally:
            cursor.close()

    # Runs for every new DBAPI connection the engine opens.
    event.listen(engine.sync_engine, "connect", _enable_sqlite_pragmas)

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    saver_context = AsyncSqliteSaver.from_conn_string(db_url.database)
    # Entered manually so the saver stays open until aclose() is called.
    checkpointer = await saver_context.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as connection:
            await connection.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_context.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(engine.dispose, _close_saver)

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
|
||||
@@ -0,0 +1,121 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from store.common import DataBaseType
|
||||
from store.config.app_config import get_app_config
|
||||
from store.config.storage_config import StorageConfig
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def storage_config_from_database_config(database_config: Any) -> StorageConfig:
    """Convert the existing public DatabaseConfig shape to StorageConfig.

    Only the ``sqlite`` and ``postgres`` backends are handled. Storage only
    owns durable database-backed persistence, so the app bridge should deal
    with memory mode before calling into this package.

    Raises:
        ValueError: If the backend is unsupported, or ``postgres`` is selected
            without a ``postgres_url``.
    """
    backend = getattr(database_config, "backend", None)
    if backend not in ("sqlite", "postgres"):
        raise ValueError(f"Unsupported database backend for storage persistence: {backend!r}")

    shared = {
        "echo_sql": getattr(database_config, "echo_sql", False),
        "pool_size": getattr(database_config, "pool_size", 5),
    }

    if backend == "sqlite":
        return StorageConfig(
            driver="sqlite",
            sqlite_dir=getattr(database_config, "sqlite_dir", ".deer-flow/data"),
            **shared,
        )

    url_text = getattr(database_config, "postgres_url", "")
    if not url_text:
        raise ValueError("database.postgres_url is required when database.backend is 'postgres'")

    url = make_url(url_text)
    return StorageConfig(
        driver="postgres",
        database_url=url_text,
        username=url.username or "",
        password=url.password or "",
        host=url.host or "localhost",
        port=url.port or 5432,
        db_name=url.database or "deerflow",
        **shared,
    )
|
||||
|
||||
|
||||
def _create_database_url(storage_config: StorageConfig) -> URL:
    """Build an async SQLAlchemy URL from StorageConfig (sqlite/mysql/postgres).

    * sqlite: points at ``storage_config.sqlite_storage_path`` and creates the
      parent directory if needed.
    * mysql / postgres with ``database_url``: the URL is used as given, except
      that a bare ``mysql`` / ``postgresql`` / ``postgres`` scheme is upgraded
      to the async driver.
    * mysql / postgres without ``database_url``: assembled from the individual
      connection fields. For MySQL, a port that was not explicitly configured
      falls back to 3306 rather than the model's PostgreSQL default.

    Raises:
        ValueError: If the driver is not one of the supported backends.
    """
    driver = storage_config.driver

    if driver == DataBaseType.sqlite:
        import os

        db_path = storage_config.sqlite_storage_path
        os.makedirs(os.path.dirname(db_path), exist_ok=True)
        return URL.create(drivername="sqlite+aiosqlite", database=db_path)

    if driver == DataBaseType.mysql:
        async_drivername = "mysql+aiomysql"
        plain_schemes: tuple[str, ...] = ("mysql",)
    elif driver in (DataBaseType.postgresql, "postgres"):
        async_drivername = "postgresql+asyncpg"
        plain_schemes = ("postgresql", "postgres")
    else:
        raise ValueError(f"Unsupported database driver: {storage_config.driver}")

    if storage_config.database_url:
        url = make_url(storage_config.database_url)
        # A scheme without an explicit driver would select a sync DBAPI.
        if url.drivername in plain_schemes:
            url = url.set(drivername=async_drivername)
        return url

    port = storage_config.port
    # StorageConfig.port defaults to 5432; only honour it for MySQL when the
    # user actually set it.
    if driver == DataBaseType.mysql and "port" not in storage_config.model_fields_set:
        port = 3306

    return URL.create(
        drivername=async_drivername,
        username=storage_config.username,
        password=storage_config.password,
        host=storage_config.host,
        port=port,
        database=storage_config.db_name or "deerflow",
    )
|
||||
|
||||
|
||||
async def create_persistence_from_storage_config(storage_config: StorageConfig) -> AppPersistence:
|
||||
from .drivers.mysql import build_mysql_persistence
|
||||
from .drivers.postgres import build_postgres_persistence
|
||||
from .drivers.sqlite import build_sqlite_persistence
|
||||
|
||||
driver = storage_config.driver
|
||||
db_url = _create_database_url(storage_config)
|
||||
|
||||
if driver in ("postgres", "postgresql"):
|
||||
return await build_postgres_persistence(
|
||||
db_url,
|
||||
echo=storage_config.echo_sql,
|
||||
pool_size=storage_config.pool_size,
|
||||
)
|
||||
|
||||
if driver == "mysql":
|
||||
return await build_mysql_persistence(
|
||||
db_url,
|
||||
echo=storage_config.echo_sql,
|
||||
pool_size=storage_config.pool_size,
|
||||
)
|
||||
|
||||
if driver == "sqlite":
|
||||
return await build_sqlite_persistence(db_url, echo=storage_config.echo_sql)
|
||||
|
||||
raise ValueError(f"Unsupported database driver: {driver}")
|
||||
|
||||
|
||||
async def create_persistence_from_database_config(database_config: Any) -> AppPersistence:
    """Convert *database_config* to a storage config and build persistence from it.

    The conversion is delegated to ``storage_config_from_database_config``; the
    result is passed to ``create_persistence_from_storage_config``.
    """
    return await create_persistence_from_storage_config(
        storage_config_from_database_config(database_config)
    )
|
||||
|
||||
|
||||
async def create_persistence() -> AppPersistence:
    """Build persistence from the ``storage`` section of the application config."""
    storage_config = get_app_config().storage
    return await create_persistence_from_storage_config(storage_config)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .close import close_in_order
|
||||
|
||||
__all__ = ["close_in_order"]
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
AsyncCloser = Callable[[], Awaitable[None]]
|
||||
|
||||
|
||||
async def close_in_order(*closers: AsyncCloser) -> None:
|
||||
"""
|
||||
Run async closers in order and raise the first error, if any.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Used to keep driver-specific close logic readable.
|
||||
- We intentionally do not stop at first failure, so later resources
|
||||
still get a chance to close.
|
||||
"""
|
||||
first_error: Exception | None = None
|
||||
|
||||
for closer in closers:
|
||||
try:
|
||||
await closer()
|
||||
except Exception as exc:
|
||||
if first_error is None:
|
||||
first_error = exc
|
||||
|
||||
if first_error is not None:
|
||||
raise first_error
|
||||
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
AsyncSetup = Callable[[], Awaitable[None]]
|
||||
AsyncClose = Callable[[], Awaitable[None]]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AppPersistence:
|
||||
"""
|
||||
Unified runtime persistence bundle.
|
||||
"""
|
||||
checkpointer: Checkpointer
|
||||
engine: AsyncEngine
|
||||
session_factory: async_sessionmaker[AsyncSession]
|
||||
setup: AsyncSetup
|
||||
aclose: AsyncClose
|
||||
@@ -0,0 +1,51 @@
|
||||
from store.repositories.contracts import (
|
||||
Feedback,
|
||||
FeedbackAggregate,
|
||||
FeedbackCreate,
|
||||
FeedbackRepositoryProtocol,
|
||||
Run,
|
||||
RunCreate,
|
||||
RunEvent,
|
||||
RunEventCreate,
|
||||
RunEventRepositoryProtocol,
|
||||
RunRepositoryProtocol,
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
User,
|
||||
UserCreate,
|
||||
UserNotFoundError,
|
||||
UserRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.factory import (
|
||||
build_feedback_repository,
|
||||
build_run_event_repository,
|
||||
build_run_repository,
|
||||
build_thread_meta_repository,
|
||||
build_user_repository,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Feedback",
|
||||
"FeedbackAggregate",
|
||||
"FeedbackCreate",
|
||||
"FeedbackRepositoryProtocol",
|
||||
"Run",
|
||||
"RunCreate",
|
||||
"RunEvent",
|
||||
"RunEventCreate",
|
||||
"RunEventRepositoryProtocol",
|
||||
"RunRepositoryProtocol",
|
||||
"ThreadMeta",
|
||||
"ThreadMetaCreate",
|
||||
"ThreadMetaRepositoryProtocol",
|
||||
"User",
|
||||
"UserCreate",
|
||||
"UserNotFoundError",
|
||||
"UserRepositoryProtocol",
|
||||
"build_run_repository",
|
||||
"build_run_event_repository",
|
||||
"build_thread_meta_repository",
|
||||
"build_feedback_repository",
|
||||
"build_user_repository",
|
||||
]
|
||||
@@ -0,0 +1,47 @@
|
||||
from store.repositories.contracts.feedback import (
|
||||
Feedback,
|
||||
FeedbackAggregate,
|
||||
FeedbackCreate,
|
||||
FeedbackRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.contracts.run import (
|
||||
Run,
|
||||
RunCreate,
|
||||
RunRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.contracts.run_event import (
|
||||
RunEvent,
|
||||
RunEventCreate,
|
||||
RunEventRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.contracts.thread_meta import (
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.contracts.user import (
|
||||
User,
|
||||
UserCreate,
|
||||
UserNotFoundError,
|
||||
UserRepositoryProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Feedback",
|
||||
"FeedbackAggregate",
|
||||
"FeedbackCreate",
|
||||
"FeedbackRepositoryProtocol",
|
||||
"Run",
|
||||
"RunCreate",
|
||||
"RunEvent",
|
||||
"RunEventCreate",
|
||||
"RunEventRepositoryProtocol",
|
||||
"RunRepositoryProtocol",
|
||||
"ThreadMeta",
|
||||
"ThreadMetaCreate",
|
||||
"ThreadMetaRepositoryProtocol",
|
||||
"User",
|
||||
"UserCreate",
|
||||
"UserNotFoundError",
|
||||
"UserRepositoryProtocol",
|
||||
]
|
||||
@@ -0,0 +1,62 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Protocol, TypedDict
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class FeedbackCreate(BaseModel):
    """Input payload for creating or upserting a feedback record.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    feedback_id: str
    run_id: str
    thread_id: str
    # The DB-backed repository accepts only +1 or -1 for the rating.
    rating: int
    user_id: str | None = None
    message_id: str | None = None
    comment: str | None = None
|
||||
|
||||
|
||||
class Feedback(BaseModel):
    """Read model for a stored feedback record.

    Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    feedback_id: str
    run_id: str
    thread_id: str
    rating: int
    user_id: str | None
    message_id: str | None
    comment: str | None
    created_time: datetime
|
||||
|
||||
|
||||
class FeedbackAggregate(TypedDict):
    """Per-run feedback counters as returned by ``aggregate_feedback_by_run``."""

    # Run the counters belong to.
    run_id: str
    # Number of feedback rows counted.
    total: int
    # Number of rows whose rating is +1.
    positive: int
    # Number of rows whose rating is -1.
    negative: int
|
||||
|
||||
|
||||
class FeedbackRepositoryProtocol(Protocol):
    """Structural interface for feedback persistence."""

    async def create_feedback(self, data: FeedbackCreate) -> Feedback:
        """Store a new feedback record built from *data* and return it."""

    async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
        """Create the feedback described by *data*, or update the existing one."""

    async def get_feedback(self, feedback_id: str) -> Feedback | None:
        """Return the feedback with *feedback_id*, or ``None``."""

    async def list_feedback_by_run(
        self,
        run_id: str,
        *,
        thread_id: str | None = None,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        """List feedback for *run_id*, optionally narrowed by thread and user."""

    async def list_feedback_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        """List feedback for *thread_id*, optionally narrowed by user."""

    async def delete_feedback(self, feedback_id: str) -> bool:
        """Delete one feedback record; the bool reports whether it existed."""

    async def delete_feedback_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        user_id: str | None = None,
    ) -> bool:
        """Delete feedback for a run; the bool reports whether anything matched."""

    async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
        """Return total/positive/negative feedback counts for one run."""
|
||||
@@ -0,0 +1,85 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class RunCreate(BaseModel):
    """Input payload for creating a run record.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    run_id: str
    thread_id: str
    assistant_id: str | None = None
    user_id: str | None = None
    status: str = "pending"
    model_name: str | None = None
    multitask_strategy: str = "reject"
    error: str | None = None
    follow_up_to_run_id: str | None = None
    # Each instance gets its own empty dict by default.
    metadata: dict[str, Any] = Field(default_factory=dict)
    kwargs: dict[str, Any] = Field(default_factory=dict)
    # When None, the DB repository does not set the creation time explicitly.
    created_time: datetime | None = None
|
||||
|
||||
|
||||
class Run(BaseModel):
    """Read model for a stored run, including its token and message counters.

    Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    # Identity and lifecycle.
    run_id: str
    thread_id: str
    assistant_id: str | None
    user_id: str | None
    status: str
    model_name: str | None
    multitask_strategy: str
    error: str | None
    follow_up_to_run_id: str | None
    metadata: dict[str, Any]
    kwargs: dict[str, Any]

    # Usage counters.
    total_input_tokens: int
    total_output_tokens: int
    total_tokens: int
    llm_call_count: int
    lead_agent_tokens: int
    subagent_tokens: int
    middleware_tokens: int
    message_count: int

    # Message excerpts and timestamps.
    first_human_message: str | None
    last_ai_message: str | None
    created_time: datetime
    updated_time: datetime | None
|
||||
|
||||
|
||||
class RunRepositoryProtocol(Protocol):
    """Structural interface for run persistence."""

    async def create_run(self, data: RunCreate) -> Run:
        """Store a new run built from *data* and return it."""

    async def get_run(self, run_id: str) -> Run | None:
        """Return the run with *run_id*, or ``None``."""

    async def list_runs_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[Run]:
        """Return one page of runs for *thread_id*, optionally narrowed by user."""

    async def update_run_status(
        self,
        run_id: str,
        status: str,
        *,
        error: str | None = None,
    ) -> None:
        """Set the status of a run, and its error text when one is given."""

    async def delete_run(self, run_id: str) -> None:
        """Delete the run with *run_id*."""

    async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
        """List pending runs, bounded by the optional *before* cutoff."""

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        first_human_message: str | None = None,
        last_ai_message: str | None = None,
        error: str | None = None,
    ) -> None:
        """Record the final status and usage counters of a run."""

    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        """Return token usage totals for the runs of *thread_id*."""
|
||||
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class RunEventCreate(BaseModel):
    """Input payload for appending one run event.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    thread_id: str
    run_id: str
    user_id: str | None = None
    event_type: str
    category: str
    # May be any value; the DB repository JSON-encodes non-string content.
    content: Any = ""
    metadata: dict[str, Any] = Field(default_factory=dict)
    # When None, the DB repository does not set the timestamp explicitly.
    created_at: datetime | None = None
|
||||
|
||||
|
||||
class RunEvent(BaseModel):
    """Read model for a stored run event.

    Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    thread_id: str
    run_id: str
    user_id: str | None
    event_type: str
    category: str
    content: Any
    metadata: dict[str, Any]
    # Sequence number; the DB repository assigns it per thread, increasing.
    seq: int
    created_at: datetime
|
||||
|
||||
|
||||
class RunEventRepositoryProtocol(Protocol):
    """Structural interface for run-event persistence."""

    async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
        """Store *events* and return them as persisted events."""

    async def list_messages(
        self,
        thread_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        """Return a page of message events of a thread, windowed by ``seq``."""

    async def list_events(
        self,
        thread_id: str,
        run_id: str,
        *,
        event_types: list[str] | None = None,
        limit: int = 500,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        """Return events of one run, optionally restricted to *event_types*."""

    async def list_messages_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        """Return a page of message events of one run, windowed by ``seq``."""

    async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
        """Return a message count for *thread_id*, optionally narrowed by user."""

    async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
        """Delete events of a thread; presumably returns the number removed."""

    async def delete_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        user_id: str | None = None,
    ) -> int:
        """Delete events of one run; presumably returns the number removed."""
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class ThreadMetaCreate(BaseModel):
    """Input payload for creating thread metadata.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    thread_id: str
    assistant_id: str | None = None
    user_id: str | None = None
    display_name: str | None = None
    status: str = "idle"
    # Each instance gets its own empty dict by default.
    metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ThreadMeta(BaseModel):
    """Read model for stored thread metadata.

    Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    thread_id: str
    assistant_id: str | None
    user_id: str | None
    display_name: str | None
    status: str
    metadata: dict[str, Any]
    created_time: datetime
    updated_time: datetime | None
|
||||
|
||||
|
||||
class ThreadMetaRepositoryProtocol(Protocol):
    """Structural interface for thread-metadata persistence."""

    async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
        """Store thread metadata built from *data* and return it."""

    async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
        """Return the metadata of *thread_id*, or ``None``."""

    async def update_thread_meta(
        self,
        thread_id: str,
        *,
        display_name: str | None = None,
        status: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Update display name, status and/or metadata of a thread.

        NOTE(review): ``None`` presumably means "leave unchanged" -- confirm
        against the implementation.
        """

    async def delete_thread(self, thread_id: str) -> None:
        """Delete the metadata of *thread_id*."""

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None = None,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        """Return one page of threads matching the given optional filters."""
|
||||
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class UserNotFoundError(LookupError):
    """Signals that an update was aimed at a user row which no longer exists."""
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
    """Input payload for creating a user.

    Unknown fields are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(extra="forbid")

    id: str
    email: str
    password_hash: str | None = None
    # Restricted to the two listed roles.
    system_role: Literal["admin", "user"] = "user"
    created_at: datetime | None = None
    oauth_provider: str | None = None
    oauth_id: str | None = None
    needs_setup: bool = False
    token_version: int = 0
|
||||
|
||||
|
||||
class User(BaseModel):
    """Read model for a stored user.

    Instances are immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)

    id: str
    email: str
    password_hash: str | None
    # Restricted to the two listed roles.
    system_role: Literal["admin", "user"]
    created_at: datetime
    oauth_provider: str | None
    oauth_id: str | None
    needs_setup: bool
    token_version: int
|
||||
|
||||
|
||||
class UserRepositoryProtocol(Protocol):
    """Structural interface for user persistence."""

    async def create_user(self, data: UserCreate) -> User:
        """Store a new user built from *data* and return it."""

    async def get_user_by_id(self, user_id: str) -> User | None:
        """Return the user with *user_id*, or ``None``."""

    async def get_user_by_email(self, email: str) -> User | None:
        """Return the user with *email*, or ``None``."""

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Return the user linked to the given OAuth identity, or ``None``."""

    async def get_first_admin(self) -> User | None:
        """Return an admin user, or ``None``."""

    async def update_user(self, data: User) -> User:
        """Persist the state in *data* and return the stored user.

        NOTE(review): presumably raises ``UserNotFoundError`` when the row is
        gone -- confirm against the implementation.
        """

    async def count_users(self) -> int:
        """Return the number of users."""

    async def count_admin_users(self) -> int:
        """Return the number of admin users."""
|
||||
@@ -0,0 +1,13 @@
|
||||
from store.repositories.db.feedback import DbFeedbackRepository
|
||||
from store.repositories.db.run import DbRunRepository
|
||||
from store.repositories.db.run_event import DbRunEventRepository
|
||||
from store.repositories.db.thread_meta import DbThreadMetaRepository
|
||||
from store.repositories.db.user import DbUserRepository
|
||||
|
||||
__all__ = [
|
||||
"DbFeedbackRepository",
|
||||
"DbRunRepository",
|
||||
"DbRunEventRepository",
|
||||
"DbThreadMetaRepository",
|
||||
"DbUserRepository",
|
||||
]
|
||||
@@ -0,0 +1,146 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import case, delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.feedback import Feedback, FeedbackAggregate, FeedbackCreate, FeedbackRepositoryProtocol
|
||||
from store.repositories.models.feedback import Feedback as FeedbackModel
|
||||
|
||||
|
||||
def _to_feedback(m: FeedbackModel) -> Feedback:
    """Copy a feedback ORM row, field by field, into a ``Feedback`` contract model."""
    copied_fields = (
        "feedback_id",
        "run_id",
        "thread_id",
        "rating",
        "user_id",
        "message_id",
        "comment",
        "created_time",
    )
    return Feedback(**{name: getattr(m, name) for name in copied_fields})
|
||||
|
||||
|
||||
class DbFeedbackRepository(FeedbackRepositoryProtocol):
    """SQLAlchemy-backed implementation of ``FeedbackRepositoryProtocol``.

    All statements run on the ``AsyncSession`` given at construction. Methods
    flush where they need database state but never commit; transaction
    boundaries belong to the caller.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    @staticmethod
    def _validate_rating(rating: int) -> None:
        """Raise ``ValueError`` unless *rating* is +1 or -1."""
        if rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {rating}")

    async def create_feedback(self, data: FeedbackCreate) -> Feedback:
        """Insert a feedback row from *data* and return it.

        Raises:
            ValueError: If ``data.rating`` is neither +1 nor -1.
        """
        self._validate_rating(data.rating)
        model = FeedbackModel(
            feedback_id=data.feedback_id,
            run_id=data.run_id,
            thread_id=data.thread_id,
            rating=data.rating,
            user_id=data.user_id,
            message_id=data.message_id,
            comment=data.comment,
        )
        self._session.add(model)
        await self._session.flush()
        # Reload so values assigned during the flush are present on the row.
        await self._session.refresh(model)
        return _to_feedback(model)

    async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
        """Create feedback for (thread, run, user), or overwrite the existing row.

        On update the rating, message id and comment are replaced and
        ``created_time`` is reset to the current UTC time; the stored
        ``feedback_id`` is kept.

        Raises:
            ValueError: If ``data.rating`` is neither +1 nor -1.
        """
        self._validate_rating(data.rating)

        result = await self._session.execute(
            select(FeedbackModel).where(
                FeedbackModel.thread_id == data.thread_id,
                FeedbackModel.run_id == data.run_id,
                FeedbackModel.user_id == data.user_id,
            )
        )
        model = result.scalar_one_or_none()
        if model is None:
            return await self.create_feedback(data)

        model.rating = data.rating
        model.message_id = data.message_id
        model.comment = data.comment
        model.created_time = datetime.now(UTC)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_feedback(model)

    async def get_feedback(self, feedback_id: str) -> Feedback | None:
        """Return the feedback with *feedback_id*, or ``None`` if there is none."""
        result = await self._session.execute(
            select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)
        )
        model = result.scalar_one_or_none()
        return _to_feedback(model) if model else None

    async def list_feedback_by_run(
        self,
        run_id: str,
        *,
        thread_id: str | None = None,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        """List feedback for *run_id*, newest first.

        ``thread_id`` and ``user_id`` narrow the result when given; ``limit``
        caps the number of rows when given.
        """
        stmt = select(FeedbackModel).where(FeedbackModel.run_id == run_id)
        if thread_id is not None:
            stmt = stmt.where(FeedbackModel.thread_id == thread_id)
        if user_id is not None:
            stmt = stmt.where(FeedbackModel.user_id == user_id)
        stmt = stmt.order_by(FeedbackModel.created_time.desc())
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self._session.execute(stmt)
        return [_to_feedback(m) for m in result.scalars().all()]

    async def list_feedback_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        """List feedback for *thread_id*, newest first.

        ``user_id`` narrows the result when given; ``limit`` caps the number
        of rows when given.
        """
        stmt = select(FeedbackModel).where(FeedbackModel.thread_id == thread_id)
        if user_id is not None:
            stmt = stmt.where(FeedbackModel.user_id == user_id)
        stmt = stmt.order_by(FeedbackModel.created_time.desc())
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self._session.execute(stmt)
        return [_to_feedback(m) for m in result.scalars().all()]

    async def delete_feedback(self, feedback_id: str) -> bool:
        """Delete one feedback row; return ``False`` if it did not exist."""
        existing = await self.get_feedback(feedback_id)
        if existing is None:
            return False
        await self._session.execute(
            delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)
        )
        return True

    async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
        """Delete the feedback of a run; return ``False`` if nothing matched.

        With ``user_id`` only that user's feedback is removed. Without it,
        every feedback row of the run is removed: several users may have
        rated the same run, so the lookup must not assume a single match
        (``scalar_one_or_none`` raises on multiple rows).
        """
        stmt = select(FeedbackModel).where(
            FeedbackModel.thread_id == thread_id,
            FeedbackModel.run_id == run_id,
        )
        if user_id is not None:
            stmt = stmt.where(FeedbackModel.user_id == user_id)
        result = await self._session.execute(stmt)
        models = result.scalars().all()
        for model in models:
            await self._session.delete(model)
        return bool(models)

    async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
        """Count total, positive (+1) and negative (-1) feedback for one run."""
        stmt = select(
            func.count().label("total"),
            # COALESCE keeps the sums at 0 when no row matches.
            func.coalesce(func.sum(case((FeedbackModel.rating == 1, 1), else_=0)), 0).label("positive"),
            func.coalesce(func.sum(case((FeedbackModel.rating == -1, 1), else_=0)), 0).label("negative"),
        ).where(FeedbackModel.thread_id == thread_id, FeedbackModel.run_id == run_id)
        row = (await self._session.execute(stmt)).one()
        return {
            "run_id": run_id,
            "total": int(row.total),
            "positive": int(row.positive),
            "negative": int(row.negative),
        }
|
||||
@@ -0,0 +1,196 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.run import Run, RunCreate, RunRepositoryProtocol
|
||||
from store.repositories.models.run import Run as RunModel
|
||||
|
||||
|
||||
def _to_run(m: RunModel) -> Run:
    """Convert a run ORM row into a ``Run`` contract model.

    The row's ``meta`` attribute is exposed as ``metadata``. ``meta`` and
    ``kwargs`` are copied into fresh dicts, with a falsy value becoming ``{}``.
    """
    copied_fields = (
        "run_id",
        "thread_id",
        "assistant_id",
        "user_id",
        "status",
        "model_name",
        "multitask_strategy",
        "error",
        "follow_up_to_run_id",
        "total_input_tokens",
        "total_output_tokens",
        "total_tokens",
        "llm_call_count",
        "lead_agent_tokens",
        "subagent_tokens",
        "middleware_tokens",
        "message_count",
        "first_human_message",
        "last_ai_message",
        "created_time",
        "updated_time",
    )
    values = {name: getattr(m, name) for name in copied_fields}
    values["metadata"] = dict(m.meta or {})
    values["kwargs"] = dict(m.kwargs or {})
    return Run(**values)
|
||||
|
||||
|
||||
class DbRunRepository(RunRepositoryProtocol):
    """SQLAlchemy-backed implementation of ``RunRepositoryProtocol``.

    Statements run on the ``AsyncSession`` supplied at construction. Nothing
    here commits; the caller owns the transaction.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_run(self, data: RunCreate) -> Run:
        """Insert a run row from *data* and return it.

        ``created_time`` is assigned only when *data* provides one.
        """
        model = RunModel(
            run_id=data.run_id,
            thread_id=data.thread_id,
            assistant_id=data.assistant_id,
            user_id=data.user_id,
            status=data.status,
            model_name=data.model_name,
            multitask_strategy=data.multitask_strategy,
            error=data.error,
            follow_up_to_run_id=data.follow_up_to_run_id,
            meta=dict(data.metadata),
            kwargs=dict(data.kwargs),
        )
        if data.created_time is not None:
            model.created_time = data.created_time
        self._session.add(model)
        await self._session.flush()
        # Reload so values assigned during the flush are present on the row.
        await self._session.refresh(model)
        return _to_run(model)

    async def get_run(self, run_id: str) -> Run | None:
        """Return the run with *run_id*, or ``None`` if there is none."""
        result = await self._session.execute(
            select(RunModel).where(RunModel.run_id == run_id)
        )
        model = result.scalar_one_or_none()
        return _to_run(model) if model else None

    async def list_runs_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[Run]:
        """Return one page of runs for *thread_id*, newest first.

        ``user_id`` narrows the result when given.
        """
        stmt = select(RunModel).where(RunModel.thread_id == thread_id)
        if user_id is not None:
            stmt = stmt.where(RunModel.user_id == user_id)
        stmt = stmt.order_by(RunModel.created_time.desc()).limit(limit).offset(offset)
        result = await self._session.execute(stmt)
        return [_to_run(m) for m in result.scalars().all()]

    async def update_run_status(
        self, run_id: str, status: str, *, error: str | None = None
    ) -> None:
        """Set the status of a run; ``error`` is written only when not ``None``."""
        values: dict[str, Any] = {"status": status}
        if error is not None:
            values["error"] = error
        await self._session.execute(
            update(RunModel).where(RunModel.run_id == run_id).values(**values)
        )

    async def delete_run(self, run_id: str) -> None:
        """Delete the run with *run_id*."""
        await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))

    async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
        """List runs with status ``"pending"`` created at or before *before*, oldest first.

        *before* may be a ``datetime`` or an ISO-8601 string; when omitted the
        current local time (timezone-aware) is used.
        """
        if before is None:
            before_dt = datetime.now().astimezone()
        elif isinstance(before, datetime):
            before_dt = before
        else:
            before_dt = datetime.fromisoformat(before)

        result = await self._session.execute(
            select(RunModel)
            .where(RunModel.status == "pending", RunModel.created_time <= before_dt)
            .order_by(RunModel.created_time.asc())
        )
        return [_to_run(m) for m in result.scalars().all()]

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        first_human_message: str | None = None,
        last_ai_message: str | None = None,
        error: str | None = None,
    ) -> None:
        """Write the final status and usage counters of a run.

        The counters are always written. The two message excerpts and
        ``error`` are written only when not ``None``; excerpts are cut to
        their first 2000 characters.
        """
        values: dict[str, Any] = {
            "status": status,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_tokens": total_tokens,
            "llm_call_count": llm_call_count,
            "lead_agent_tokens": lead_agent_tokens,
            "subagent_tokens": subagent_tokens,
            "middleware_tokens": middleware_tokens,
            "message_count": message_count,
        }
        if first_human_message is not None:
            values["first_human_message"] = first_human_message[:2000]
        if last_ai_message is not None:
            values["last_ai_message"] = last_ai_message[:2000]
        if error is not None:
            values["error"] = error
        await self._session.execute(
            update(RunModel).where(RunModel.run_id == run_id).values(**values)
        )

    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        """Sum token usage over the finished runs of *thread_id*.

        Only runs whose status is ``"success"`` or ``"error"`` are counted.
        Runs without a model name are grouped under ``"unknown"``.

        Returns:
            A dict with ``total_tokens``, ``total_input_tokens``,
            ``total_output_tokens``, ``total_runs``, ``by_model`` (model name
            -> ``{"tokens", "runs"}``) and ``by_caller`` (``lead_agent``,
            ``subagent``, ``middleware``).
        """
        completed = RunModel.status.in_(("success", "error"))
        stmt = (
            select(
                func.coalesce(RunModel.model_name, "unknown").label("model"),
                func.count().label("runs"),
                func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
                func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
                func.coalesce(func.sum(RunModel.total_output_tokens), 0).label("total_output_tokens"),
                func.coalesce(func.sum(RunModel.lead_agent_tokens), 0).label("lead_agent"),
                func.coalesce(func.sum(RunModel.subagent_tokens), 0).label("subagent"),
                func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
            )
            .where(RunModel.thread_id == thread_id, completed)
            # Group by the select-list label. A second COALESCE(...) built here
            # would carry its own bind parameter, and with server-side binds
            # PostgreSQL cannot tell that it equals the selected expression.
            .group_by("model")
        )

        rows = (await self._session.execute(stmt)).all()
        total_tokens = total_input = total_output = total_runs = 0
        lead_agent = subagent = middleware = 0
        by_model: dict[str, dict[str, int]] = {}
        for row in rows:
            # SUM may come back as a non-int numeric type on some backends;
            # normalise to int like the feedback aggregate does.
            row_tokens = int(row.total_tokens)
            row_runs = int(row.runs)
            by_model[row.model] = {"tokens": row_tokens, "runs": row_runs}
            total_tokens += row_tokens
            total_input += int(row.total_input_tokens)
            total_output += int(row.total_output_tokens)
            total_runs += row_runs
            lead_agent += int(row.lead_agent)
            subagent += int(row.subagent)
            middleware += int(row.middleware)

        return {
            "total_tokens": total_tokens,
            "total_input_tokens": total_input,
            "total_output_tokens": total_output,
            "total_runs": total_runs,
            "by_model": by_model,
            "by_caller": {
                "lead_agent": lead_agent,
                "subagent": subagent,
                "middleware": middleware,
            },
        }
|
||||
@@ -0,0 +1,195 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
|
||||
from store.repositories.models.run_event import RunEvent as RunEventModel
|
||||
|
||||
|
||||
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||
if not isinstance(content, str):
|
||||
next_metadata = {**metadata, "content_is_json": True}
|
||||
if isinstance(content, dict):
|
||||
next_metadata["content_is_dict"] = True
|
||||
return json.dumps(content, default=str, ensure_ascii=False), next_metadata
|
||||
return content, metadata
|
||||
|
||||
|
||||
def _deserialize_content(content: str, metadata: dict[str, Any]) -> Any:
|
||||
if not (metadata.get("content_is_json") or metadata.get("content_is_dict")):
|
||||
return content
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
|
||||
def _to_run_event(model: RunEventModel) -> RunEvent:
    """Convert a run-event ORM row into a ``RunEvent`` contract model.

    Content is decoded using the stored metadata. The ``content_is_dict``
    marker is removed from the metadata handed back to callers.
    """
    stored_meta = dict(model.meta or {})
    public_meta = dict(stored_meta)
    public_meta.pop("content_is_dict", None)
    decoded_content = _deserialize_content(model.content, stored_meta)
    return RunEvent(
        thread_id=model.thread_id,
        run_id=model.run_id,
        user_id=model.user_id,
        event_type=model.event_type,
        category=model.category,
        content=decoded_content,
        metadata=public_meta,
        seq=model.seq,
        created_at=model.created_at,
    )
|
||||
|
||||
|
||||
class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
    """Keep *session*; every statement of this repository runs on it."""
    self._session = session
|
||||
|
||||
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
    """Append *events*, numbering them per thread, and return the stored events.

    For each thread in the batch the current highest ``seq`` is read and the
    new events continue from it in input order. Non-string content is
    JSON-encoded before storage. The session is flushed, not committed.

    Args:
        events: Events to store; an empty list is a no-op.

    Returns:
        The stored events in input order.
    """
    if not events:
        return []

    seq_by_thread: dict[str, int] = {}
    # Visit threads in a stable order so concurrent batches touching the same
    # threads take their row locks in the same sequence.
    for thread_id in sorted({event.thread_id for event in events}):
        # Lock the latest row instead of selecting max(seq): PostgreSQL
        # rejects a locking clause combined with an aggregate.
        max_seq = await self._session.scalar(
            select(RunEventModel.seq)
            .where(RunEventModel.thread_id == thread_id)
            .order_by(RunEventModel.seq.desc())
            .limit(1)
            .with_for_update()
        )
        # No rows yet (or a NULL seq) means numbering starts at 1.
        seq_by_thread[thread_id] = max_seq or 0

    rows: list[RunEventModel] = []
    for event in events:
        seq_by_thread[event.thread_id] += 1
        # Serialise against a copy so the caller's metadata dict is untouched.
        content, metadata = _serialize_content(event.content, dict(event.metadata))
        row = RunEventModel(
            thread_id=event.thread_id,
            run_id=event.run_id,
            user_id=event.user_id,
            seq=seq_by_thread[event.thread_id],
            event_type=event.event_type,
            category=event.category,
            content=content,
            meta=metadata,
        )
        if event.created_at is not None:
            row.created_at = event.created_at
        self._session.add(row)
        rows.append(row)

    await self._session.flush()
    return [_to_run_event(row) for row in rows]
|
||||
|
||||
async def list_messages(
    self,
    thread_id: str,
    *,
    limit: int = 50,
    before_seq: int | None = None,
    after_seq: int | None = None,
    user_id: str | None = None,
) -> list[RunEvent]:
    """Return up to *limit* ``"message"`` events of a thread in ascending ``seq``.

    With ``before_seq`` the page holds the events closest below it; otherwise
    with ``after_seq`` the first events above it; with neither, the latest
    events. ``before_seq`` wins when both are given. ``user_id`` narrows the
    result when given.
    """
    filters = [
        RunEventModel.thread_id == thread_id,
        RunEventModel.category == "message",
    ]
    if user_id is not None:
        filters.append(RunEventModel.user_id == user_id)

    # Only the pure "after" window is read oldest-first; the other two are
    # read newest-first and flipped afterwards.
    read_ascending = before_seq is None and after_seq is not None
    if before_seq is not None:
        filters.append(RunEventModel.seq < before_seq)
    elif after_seq is not None:
        filters.append(RunEventModel.seq > after_seq)

    ordering = RunEventModel.seq.asc() if read_ascending else RunEventModel.seq.desc()
    query = select(RunEventModel).where(*filters).order_by(ordering).limit(limit)
    fetched = (await self._session.execute(query)).scalars().all()

    page = [_to_run_event(row) for row in fetched]
    if not read_ascending:
        page.reverse()
    return page
|
||||
|
||||
async def list_events(
    self,
    thread_id: str,
    run_id: str,
    *,
    event_types: list[str] | None = None,
    limit: int = 500,
    user_id: str | None = None,
) -> list[RunEvent]:
    """Return events of a single run in ascending ``seq`` order.

    Args:
        thread_id: Thread the run belongs to.
        run_id: Run whose events are listed.
        event_types: Optional allow-list of ``event_type`` values; ``None``
            applies no type filter.
        limit: Maximum number of events returned.
        user_id: Optional owner filter; ``None`` applies no owner filter.
    """
    filters = [
        RunEventModel.thread_id == thread_id,
        RunEventModel.run_id == run_id,
    ]
    if user_id is not None:
        filters.append(RunEventModel.user_id == user_id)
    if event_types is not None:
        filters.append(RunEventModel.event_type.in_(event_types))

    query = select(RunEventModel).where(*filters).order_by(RunEventModel.seq.asc()).limit(limit)
    fetched = await self._session.execute(query)
    return [_to_run_event(item) for item in fetched.scalars().all()]
|
||||
|
||||
async def list_messages_by_run(
    self,
    thread_id: str,
    run_id: str,
    *,
    limit: int = 50,
    before_seq: int | None = None,
    after_seq: int | None = None,
    user_id: str | None = None,
) -> list[RunEvent]:
    """Return one page of ``"message"``-category events of a single run, oldest first.

    Args:
        thread_id: Thread the run belongs to.
        run_id: Run whose message events are listed.
        limit: Maximum number of events returned.
        before_seq: When given, only events with ``seq`` strictly below this
            value are considered and the page holds the ones closest to it.
        after_seq: When given (and ``before_seq`` is not), only events with
            ``seq`` strictly above this value are considered, starting from
            the smallest.
        user_id: Optional owner filter; ``None`` applies no owner filter.

    Returns:
        Events in ascending ``seq`` order. With no cursor the page holds the
        events with the highest ``seq`` values.
    """
    filters = [
        RunEventModel.thread_id == thread_id,
        RunEventModel.run_id == run_id,
        RunEventModel.category == "message",
    ]
    if user_id is not None:
        filters.append(RunEventModel.user_id == user_id)

    # ``before_seq`` takes precedence when both cursors are supplied.
    newest_first = True
    if before_seq is not None:
        filters.append(RunEventModel.seq < before_seq)
    elif after_seq is not None:
        filters.append(RunEventModel.seq > after_seq)
        newest_first = False

    ordering = RunEventModel.seq.desc() if newest_first else RunEventModel.seq.asc()
    query = select(RunEventModel).where(*filters).order_by(ordering).limit(limit)
    fetched = await self._session.execute(query)
    events = [_to_run_event(item) for item in fetched.scalars().all()]
    if newest_first:
        # Rows were read newest-first to honour the limit; flip to ascending.
        events.reverse()
    return events
|
||||
|
||||
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
    """Count the ``"message"``-category events stored for a thread.

    Args:
        thread_id: Thread whose message events are counted.
        user_id: Optional owner filter; ``None`` applies no owner filter.

    Returns:
        The number of matching rows, ``0`` when the query yields nothing.
    """
    filters = [
        RunEventModel.thread_id == thread_id,
        RunEventModel.category == "message",
    ]
    if user_id is not None:
        filters.append(RunEventModel.user_id == user_id)

    query = select(func.count()).select_from(RunEventModel).where(*filters)
    total = await self._session.scalar(query)
    return int(total) if total else 0
|
||||
|
||||
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
    """Delete every event of a thread and report how many rows matched.

    Args:
        thread_id: Thread whose events are removed.
        user_id: Optional owner filter; ``None`` applies no owner filter.

    Returns:
        The row count measured by a ``COUNT`` query issued just before the
        ``DELETE`` (both statements share the same filter).
    """
    criteria = (RunEventModel.thread_id == thread_id,)
    if user_id is not None:
        criteria += (RunEventModel.user_id == user_id,)

    count_query = select(func.count()).select_from(RunEventModel).where(*criteria)
    matched = await self._session.scalar(count_query)

    removal = delete(RunEventModel).where(*criteria)
    await self._session.execute(removal)
    return int(matched or 0)
|
||||
|
||||
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
    """Delete every event of one run and report how many rows matched.

    Args:
        thread_id: Thread the run belongs to.
        run_id: Run whose events are removed.
        user_id: Optional owner filter; ``None`` applies no owner filter.

    Returns:
        The row count measured by a ``COUNT`` query issued just before the
        ``DELETE`` (both statements share the same filter).
    """
    criteria = (
        RunEventModel.thread_id == thread_id,
        RunEventModel.run_id == run_id,
    )
    if user_id is not None:
        criteria += (RunEventModel.user_id == user_id,)

    count_query = select(func.count()).select_from(RunEventModel).where(*criteria)
    matched = await self._session.scalar(count_query)

    removal = delete(RunEventModel).where(*criteria)
    await self._session.execute(removal)
    return int(matched or 0)
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.thread_meta import ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol
|
||||
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
|
||||
|
||||
|
||||
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
    """Convert a ``ThreadMetaModel`` ORM row into the ``ThreadMeta`` contract object."""
    # Always hand out a fresh dict so callers cannot mutate the ORM row's
    # JSON value; a missing/empty value becomes an empty dict.
    metadata = dict(m.meta) if m.meta else {}
    return ThreadMeta(
        thread_id=m.thread_id,
        assistant_id=m.assistant_id,
        user_id=m.user_id,
        display_name=m.display_name,
        status=m.status,
        metadata=metadata,
        created_time=m.created_time,
        updated_time=m.updated_time,
    )
|
||||
|
||||
|
||||
class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
    """Thread metadata repository backed by a caller-supplied ``AsyncSession``.

    No method here commits; the caller owns the transaction boundary.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
        """Insert a new thread metadata row and return it as stored.

        The row is flushed and refreshed so the returned object carries the
        values present on the persisted row.
        """
        model = ThreadMetaModel(
            thread_id=data.thread_id,
            assistant_id=data.assistant_id,
            user_id=data.user_id,
            display_name=data.display_name,
            status=data.status,
            # Copy so later mutation of the input does not leak into the row.
            meta=dict(data.metadata),
        )
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_thread_meta(model)

    async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
        """Return the metadata of ``thread_id``, or ``None`` when no row exists."""
        result = await self._session.execute(select(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
        model = result.scalar_one_or_none()
        return _to_thread_meta(model) if model is not None else None

    async def update_thread_meta(
        self,
        thread_id: str,
        *,
        display_name: str | None = None,
        status: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Update the supplied fields of a thread; ``None`` means "leave unchanged".

        ``metadata`` replaces the stored JSON value as a whole. When every
        argument is ``None`` no statement is issued.
        """
        values: dict[str, Any] = {}
        if display_name is not None:
            values["display_name"] = display_name
        if status is not None:
            values["status"] = status
        if metadata is not None:
            values["meta"] = dict(metadata)
        if not values:
            return
        await self._session.execute(
            update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values)
        )

    async def delete_thread(self, thread_id: str) -> None:
        """Delete the metadata row of ``thread_id`` (a no-op when it does not exist)."""
        await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None = None,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        """Return threads matching every supplied filter, newest first.

        Args:
            metadata: Key/value pairs that must all match. Each stored JSON
                value is compared as a string against ``str(value)``.
            status: Exact status filter.
            user_id: Exact owner filter.
            assistant_id: Exact assistant filter.
            limit: Page size.
            offset: Number of leading rows to skip.
        """
        stmt = select(ThreadMetaModel)

        if status is not None:
            stmt = stmt.where(ThreadMetaModel.status == status)
        if user_id is not None:
            stmt = stmt.where(ThreadMetaModel.user_id == user_id)
        if assistant_id is not None:
            stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
        if metadata:
            for key, value in metadata.items():
                stmt = stmt.where(ThreadMetaModel.meta[key].as_string() == str(value))

        # The primary key breaks ties between equal timestamps so that
        # LIMIT/OFFSET paging cannot repeat or skip rows across pages.
        stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
        stmt = stmt.limit(limit).offset(offset)

        result = await self._session.execute(stmt)
        return [_to_thread_meta(m) for m in result.scalars().all()]
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.user import User, UserCreate, UserNotFoundError, UserRepositoryProtocol
|
||||
from store.repositories.models.user import User as UserModel
|
||||
|
||||
|
||||
def _to_user(model: UserModel) -> User:
    """Convert a ``UserModel`` ORM row into the ``User`` contract object."""
    return User(
        # Identity and credentials.
        id=model.id,
        email=model.email,
        password_hash=model.password_hash,
        oauth_provider=model.oauth_provider,
        oauth_id=model.oauth_id,
        # Role and account state.
        system_role=model.system_role,  # type: ignore[arg-type]
        needs_setup=model.needs_setup,
        token_version=model.token_version,
        created_at=model.created_at,
    )
|
||||
|
||||
|
||||
class DbUserRepository(UserRepositoryProtocol):
    """User repository backed by a caller-supplied ``AsyncSession``.

    Methods flush but never commit; the caller owns the transaction boundary.
    """

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_user(self, data: UserCreate) -> User:
        """Insert a new user and return it as stored.

        Raises:
            ValueError: If the flush fails with an ``IntegrityError``. The
                session is rolled back before the error is re-raised.
        """
        model = UserModel(
            id=data.id,
            email=data.email,
            system_role=data.system_role,
            password_hash=data.password_hash,
            oauth_provider=data.oauth_provider,
            oauth_id=data.oauth_id,
            needs_setup=data.needs_setup,
            token_version=data.token_version,
        )
        # ``created_at`` is assigned after construction, and only when the
        # caller supplied one; otherwise the model's own default is kept.
        if data.created_at is not None:
            model.created_at = data.created_at
        self._session.add(model)
        try:
            await self._session.flush()
        except IntegrityError as exc:
            await self._session.rollback()
            raise ValueError(f"Email already registered: {data.email}") from exc
        await self._session.refresh(model)
        return _to_user(model)

    async def get_user_by_id(self, user_id: str) -> User | None:
        """Return the user with primary key ``user_id``, or ``None``."""
        model = await self._session.get(UserModel, user_id)
        return _to_user(model) if model is not None else None

    async def get_user_by_email(self, email: str) -> User | None:
        """Return the user whose email equals ``email``, or ``None``."""
        result = await self._session.execute(select(UserModel).where(UserModel.email == email))
        model = result.scalar_one_or_none()
        return _to_user(model) if model is not None else None

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        """Return the user linked to the given OAuth identity, or ``None``."""
        result = await self._session.execute(
            select(UserModel).where(
                UserModel.oauth_provider == provider,
                UserModel.oauth_id == oauth_id,
            )
        )
        model = result.scalar_one_or_none()
        return _to_user(model) if model is not None else None

    async def get_first_admin(self) -> User | None:
        """Return the earliest-created admin user, or ``None`` when there is none."""
        # LIMIT without ORDER BY returns an arbitrary row; order explicitly so
        # repeated calls agree on which admin is "first". ``id`` breaks ties.
        stmt = (
            select(UserModel)
            .where(UserModel.system_role == "admin")
            .order_by(UserModel.created_at.asc(), UserModel.id.asc())
            .limit(1)
        )
        result = await self._session.execute(stmt)
        model = result.scalar_one_or_none()
        return _to_user(model) if model is not None else None

    async def update_user(self, data: User) -> User:
        """Overwrite the mutable fields of an existing user and return the stored row.

        ``id`` and ``created_at`` are left untouched.

        Raises:
            UserNotFoundError: If no row with ``data.id`` exists.
        """
        model = await self._session.get(UserModel, data.id)
        if model is None:
            raise UserNotFoundError(f"User {data.id} no longer exists")

        model.email = data.email
        model.password_hash = data.password_hash
        model.system_role = data.system_role
        model.oauth_provider = data.oauth_provider
        model.oauth_id = data.oauth_id
        model.needs_setup = data.needs_setup
        model.token_version = data.token_version

        await self._session.flush()
        await self._session.refresh(model)
        return _to_user(model)

    async def count_users(self) -> int:
        """Return the total number of user rows."""
        count = await self._session.scalar(select(func.count()).select_from(UserModel))
        return int(count or 0)

    async def count_admin_users(self) -> int:
        """Return the number of users whose ``system_role`` is ``"admin"``."""
        count = await self._session.scalar(
            select(func.count()).select_from(UserModel).where(UserModel.system_role == "admin")
        )
        return int(count or 0)
|
||||
@@ -0,0 +1,36 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories import (
|
||||
FeedbackRepositoryProtocol,
|
||||
RunEventRepositoryProtocol,
|
||||
RunRepositoryProtocol,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
UserRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.db import (
|
||||
DbFeedbackRepository,
|
||||
DbRunEventRepository,
|
||||
DbRunRepository,
|
||||
DbThreadMetaRepository,
|
||||
DbUserRepository,
|
||||
)
|
||||
|
||||
|
||||
def build_thread_meta_repository(session: AsyncSession) -> ThreadMetaRepositoryProtocol:
    """Create the database-backed thread metadata repository bound to ``session``."""
    repository = DbThreadMetaRepository(session)
    return repository
|
||||
|
||||
|
||||
def build_run_repository(session: AsyncSession) -> RunRepositoryProtocol:
    """Create the database-backed run repository bound to ``session``."""
    repository = DbRunRepository(session)
    return repository
|
||||
|
||||
|
||||
def build_feedback_repository(session: AsyncSession) -> FeedbackRepositoryProtocol:
    """Create the database-backed feedback repository bound to ``session``."""
    repository = DbFeedbackRepository(session)
    return repository
|
||||
|
||||
|
||||
def build_run_event_repository(session: AsyncSession) -> RunEventRepositoryProtocol:
    """Create the database-backed run event repository bound to ``session``."""
    repository = DbRunEventRepository(session)
    return repository
|
||||
|
||||
|
||||
def build_user_repository(session: AsyncSession) -> UserRepositoryProtocol:
    """Create the database-backed user repository bound to ``session``."""
    repository = DbUserRepository(session)
    return repository
|
||||
@@ -0,0 +1,7 @@
|
||||
from store.repositories.models.feedback import Feedback
|
||||
from store.repositories.models.run import Run
|
||||
from store.repositories.models.run_event import RunEvent
|
||||
from store.repositories.models.thread_meta import ThreadMeta
|
||||
from store.repositories.models.user import User
|
||||
|
||||
__all__ = ["Feedback", "Run", "RunEvent", "ThreadMeta", "User"]
|
||||
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Integer, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
|
||||
|
||||
class Feedback(DataClassBase):
    """Feedback table (create-only, no updated_time)."""

    __tablename__ = "feedback"
    __table_args__ = (
        # One feedback row per (thread, run, user) combination.
        UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
        {"comment": "Feedback table."},
    )

    # --- required fields -------------------------------------------------
    feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    run_id: Mapped[str] = mapped_column(String(64), index=True)
    thread_id: Mapped[str] = mapped_column(String(64), index=True)
    rating: Mapped[int] = mapped_column(Integer)

    # --- optional fields (default to None) -------------------------------
    user_id: Mapped[str | None] = mapped_column(String(64), index=True, default=None)
    message_id: Mapped[str | None] = mapped_column(String(64), default=None)
    comment: Mapped[str | None] = mapped_column(UniversalText, default=None)

    # Attribute ``created_time`` is stored in the ``created_at`` column; it is
    # excluded from ``__init__`` and filled from the configured timezone clock.
    created_time: Mapped[datetime] = mapped_column(
        "created_at",
        TimeZone,
        comment="Created at",
        default_factory=_tz.now,
        init=False,
        sort_order=999,
    )
|
||||
@@ -0,0 +1,66 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, Index, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
|
||||
|
||||
class Run(DataClassBase):
    """Run metadata table."""

    __tablename__ = "runs"
    __table_args__ = (
        # Composite index over a thread's runs and their status.
        Index("ix_runs_thread_status", "thread_id", "status"),
        {"comment": "Run metadata table."},
    )

    # --- identity --------------------------------------------------------
    run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    thread_id: Mapped[str] = mapped_column(String(64), index=True)

    # --- descriptive fields ----------------------------------------------
    assistant_id: Mapped[str | None] = mapped_column(String(128), default=None)
    user_id: Mapped[str | None] = mapped_column(String(64), index=True, default=None)
    status: Mapped[str] = mapped_column(String(20), index=True, default="pending")
    model_name: Mapped[str | None] = mapped_column(String(128), default=None)
    multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
    error: Mapped[str | None] = mapped_column(UniversalText, default=None)
    follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64), default=None)

    # --- JSON payloads; attribute names differ from the column names ------
    meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
    kwargs: Mapped[dict[str, Any]] = mapped_column("kwargs_json", JSON, default_factory=dict)

    # --- token accounting ------------------------------------------------
    total_input_tokens: Mapped[int] = mapped_column(Integer, default=0)
    total_output_tokens: Mapped[int] = mapped_column(Integer, default=0)
    total_tokens: Mapped[int] = mapped_column(Integer, default=0)
    llm_call_count: Mapped[int] = mapped_column(Integer, default=0)
    lead_agent_tokens: Mapped[int] = mapped_column(Integer, default=0)
    subagent_tokens: Mapped[int] = mapped_column(Integer, default=0)
    middleware_tokens: Mapped[int] = mapped_column(Integer, default=0)

    # --- message summary -------------------------------------------------
    message_count: Mapped[int] = mapped_column(Integer, default=0)
    first_human_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
    last_ai_message: Mapped[str | None] = mapped_column(UniversalText, default=None)

    # --- timestamps (excluded from ``__init__``) -------------------------
    created_time: Mapped[datetime] = mapped_column(
        "created_at",
        TimeZone,
        comment="Created at",
        default_factory=_tz.now,
        init=False,
        sort_order=999,
    )
    # Starts as None and is set by ``onupdate`` on UPDATE statements.
    updated_time: Mapped[datetime | None] = mapped_column(
        "updated_at",
        TimeZone,
        comment="Updated at",
        default=None,
        init=False,
        onupdate=_tz.now,
        sort_order=999,
    )
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, Index, Integer, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
|
||||
|
||||
class RunEvent(DataClassBase):
    """Run event table."""

    __tablename__ = "run_events"
    __table_args__ = (
        # ``seq`` is unique within a thread.
        UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
        Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
        Index("ix_events_run", "thread_id", "run_id", "seq"),
        {"comment": "Run event table."},
    )

    # Primary key; not an ``__init__`` argument.
    id: Mapped[id_key] = mapped_column(init=False)

    # --- required fields -------------------------------------------------
    thread_id: Mapped[str] = mapped_column(String(64), index=True)
    run_id: Mapped[str] = mapped_column(String(64), index=True)
    event_type: Mapped[str] = mapped_column(String(32), index=True)
    category: Mapped[str] = mapped_column(String(16), index=True)

    # --- fields with defaults --------------------------------------------
    user_id: Mapped[str | None] = mapped_column(String(64), index=True, default=None)
    seq: Mapped[int] = mapped_column(Integer, index=True, default=0)
    content: Mapped[str] = mapped_column(UniversalText, default="")
    # Attribute ``meta`` is stored in the ``event_metadata`` column.
    meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
    created_at: Mapped[datetime] = mapped_column(
        TimeZone,
        comment="Event timestamp",
        default_factory=_tz.now,
        init=False,
        sort_order=999,
    )
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
|
||||
|
||||
class ThreadMeta(DataClassBase):
    """Thread metadata table."""

    __tablename__ = "threads_meta"
    __table_args__ = {"comment": "Thread metadata table."}

    # --- identity --------------------------------------------------------
    thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)

    # --- descriptive fields ----------------------------------------------
    assistant_id: Mapped[str | None] = mapped_column(String(128), index=True, default=None)
    user_id: Mapped[str | None] = mapped_column(String(64), index=True, default=None)
    display_name: Mapped[str | None] = mapped_column(String(256), default=None)
    status: Mapped[str] = mapped_column(String(20), index=True, default="idle")

    # Attribute ``meta`` is stored in the ``metadata_json`` column.
    meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)

    # --- timestamps (excluded from ``__init__``) -------------------------
    created_time: Mapped[datetime] = mapped_column(
        "created_at",
        TimeZone,
        comment="Created at",
        default_factory=_tz.now,
        init=False,
        sort_order=999,
    )
    # Starts as None and is set by ``onupdate`` on UPDATE statements.
    updated_time: Mapped[datetime | None] = mapped_column(
        "updated_at",
        TimeZone,
        comment="Updated at",
        default=None,
        init=False,
        onupdate=_tz.now,
        sort_order=999,
    )
|
||||
@@ -0,0 +1,45 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, Index, String, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
|
||||
|
||||
class User(DataClassBase):
    """User account table."""

    __tablename__ = "users"
    __table_args__ = (
        # Unique OAuth identity. The partial-index predicate is declared for
        # SQLite only (``sqlite_where``).
        Index(
            "idx_users_oauth_identity",
            "oauth_provider",
            "oauth_id",
            sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
            unique=True,
        ),
        {"comment": "User account table."},
    )

    # --- identity --------------------------------------------------------
    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    email: Mapped[str] = mapped_column(String(320), index=True, nullable=False, unique=True)
    system_role: Mapped[str] = mapped_column(String(16), default="user")

    # --- credentials and account state -----------------------------------
    password_hash: Mapped[str | None] = mapped_column(String(128), default=None)
    oauth_provider: Mapped[str | None] = mapped_column(String(32), default=None)
    oauth_id: Mapped[str | None] = mapped_column(String(128), default=None)
    needs_setup: Mapped[bool] = mapped_column(Boolean, default=False)
    token_version: Mapped[int] = mapped_column(default=0)

    # Excluded from ``__init__``; filled from the configured timezone clock.
    created_at: Mapped[datetime] = mapped_column(
        TimeZone,
        comment="Created at",
        default_factory=_tz.now,
        init=False,
        sort_order=999,
    )
|
||||
@@ -0,0 +1,3 @@
|
||||
from .timezone import get_timezone
|
||||
|
||||
__all__ = ["get_timezone"]
|
||||
@@ -0,0 +1,51 @@
|
||||
import zoneinfo
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from store.config.app_config import get_app_config
|
||||
|
||||
# IANA identifiers that map to UTC — see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones
|
||||
_UTC_IDENTIFIERS = frozenset({"Etc/UCT", "Etc/Universal", "Etc/UTC", "Etc/Zulu", "UCT", "Universal", "UTC", "Zulu"})
|
||||
|
||||
|
||||
class TimeZone:
|
||||
def __init__(self) -> None:
|
||||
app_config = get_app_config()
|
||||
if app_config.timezone in _UTC_IDENTIFIERS:
|
||||
self.tz_info = UTC
|
||||
else:
|
||||
self.tz_info = zoneinfo.ZoneInfo(app_config.timezone)
|
||||
|
||||
def now(self) -> datetime:
|
||||
"""Return the current time in the configured timezone."""
|
||||
return datetime.now(self.tz_info)
|
||||
|
||||
def from_datetime(self, t: datetime) -> datetime:
|
||||
"""Convert a datetime to the configured timezone."""
|
||||
return t.astimezone(self.tz_info)
|
||||
|
||||
def from_str(self, t_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> datetime:
|
||||
"""Parse a time string and attach the configured timezone."""
|
||||
return datetime.strptime(t_str, format_str).replace(tzinfo=self.tz_info)
|
||||
|
||||
@staticmethod
|
||||
def to_str(t: datetime, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
"""Format a datetime to string."""
|
||||
return t.strftime(format_str)
|
||||
|
||||
@staticmethod
|
||||
def to_utc(t: datetime | int) -> datetime:
|
||||
"""Convert a datetime or Unix timestamp to UTC."""
|
||||
if isinstance(t, datetime):
|
||||
return t.astimezone(UTC)
|
||||
return datetime.fromtimestamp(t, tz=UTC)
|
||||
|
||||
|
||||
_timezone = None
|
||||
|
||||
|
||||
def get_timezone() -> TimeZone:
    """Return the process-wide ``TimeZone`` instance, creating it on first use."""
    global _timezone
    if _timezone is not None:
        return _timezone
    # First call: build the instance and cache it in the module global.
    _timezone = TimeZone()
    return _timezone
|
||||
@@ -6,6 +6,7 @@ readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"deerflow-harness",
|
||||
"deerflow-storage",
|
||||
"fastapi>=0.115.0",
|
||||
"httpx>=0.28.0",
|
||||
"python-multipart>=0.0.27",
|
||||
@@ -24,7 +25,7 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
postgres = ["deerflow-harness[postgres]"]
|
||||
postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
@@ -43,7 +44,8 @@ markers = [
|
||||
index-url = "https://pypi.org/simple"
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["packages/harness"]
|
||||
members = ["packages/harness", "packages/storage"]
|
||||
|
||||
[tool.uv.sources]
|
||||
deerflow-harness = { workspace = true }
|
||||
deerflow-storage = { workspace = true }
|
||||
|
||||
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from store.config.storage_config import StorageConfig
|
||||
from store.persistence.factory import _create_database_url, storage_config_from_database_config
|
||||
|
||||
|
||||
def test_database_sqlite_config_maps_to_storage_config(tmp_path):
    """A sqlite database config maps field-for-field onto ``StorageConfig``."""
    sqlite_dir = str(tmp_path)
    db_config = SimpleNamespace(
        backend="sqlite",
        sqlite_dir=sqlite_dir,
        echo_sql=True,
        pool_size=9,
    )

    actual = storage_config_from_database_config(db_config)

    expected = StorageConfig(
        driver="sqlite",
        sqlite_dir=sqlite_dir,
        echo_sql=True,
        pool_size=9,
    )
    assert actual == expected
    # The database file is placed inside the configured directory.
    assert actual.sqlite_storage_path == str(tmp_path / "deerflow.db")
|
||||
|
||||
|
||||
def test_database_memory_config_is_not_a_storage_backend():
    """The ``memory`` backend is rejected as a storage backend."""
    memory_config = SimpleNamespace(backend="memory")

    with pytest.raises(ValueError, match="Unsupported database backend"):
        storage_config_from_database_config(memory_config)
|
||||
|
||||
|
||||
def test_database_postgres_config_preserves_url_and_pool_options():
    """A postgres URL is kept verbatim, split into its parts, and mapped to asyncpg."""
    db_config = SimpleNamespace(
        backend="postgres",
        postgres_url="postgresql://user:pass@db.example:5544/deerflow",
        echo_sql=True,
        pool_size=11,
    )

    storage = storage_config_from_database_config(db_config)
    engine_url = _create_database_url(storage)

    # The original URL and the engine options pass through unchanged.
    assert storage.driver == "postgres"
    assert storage.database_url == "postgresql://user:pass@db.example:5544/deerflow"
    # The URL is also decomposed into individual connection fields.
    assert storage.username == "user"
    assert storage.password == "pass"
    assert storage.host == "db.example"
    assert storage.port == 5544
    assert storage.db_name == "deerflow"
    assert storage.echo_sql is True
    assert storage.pool_size == 11
    # The SQLAlchemy URL uses the async driver.
    assert engine_url.drivername == "postgresql+asyncpg"
    assert engine_url.database == "deerflow"
|
||||
|
||||
|
||||
def test_database_postgres_requires_url():
    """An empty ``postgres_url`` is rejected for the postgres backend."""
    missing_url_config = SimpleNamespace(backend="postgres", postgres_url="")

    with pytest.raises(ValueError, match="database.postgres_url is required"):
        storage_config_from_database_config(missing_url_config)
|
||||
|
||||
|
||||
def test_unsupported_database_backend_rejected():
    """A backend name the storage layer does not know is rejected."""
    oracle_config = SimpleNamespace(backend="oracle")

    with pytest.raises(ValueError, match="Unsupported database backend"):
        storage_config_from_database_config(oracle_config)
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from sqlalchemy import inspect
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.repositories import UserCreate, build_user_repository
|
||||
|
||||
|
||||
def test_sqlite_persistence_from_database_config_creates_storage_tables(tmp_path):
    """Setting up sqlite persistence creates the storage tables and round-trips a user."""
    expected_tables = {
        "users",
        "runs",
        "run_events",
        "threads_meta",
        "feedback",
    }

    def _table_names(sync_conn) -> set[str]:
        # Runs on the synchronous side of the connection via ``run_sync``.
        return set(inspect(sync_conn).get_table_names())

    async def scenario() -> None:
        db_config = SimpleNamespace(
            backend="sqlite",
            sqlite_dir=str(tmp_path),
            echo_sql=False,
            pool_size=5,
        )
        persistence = await create_persistence_from_database_config(db_config)
        assert persistence is not None
        try:
            await persistence.setup()

            async with persistence.engine.connect() as conn:
                tables = await conn.run_sync(_table_names)

            assert expected_tables <= tables

            # Write through one session ...
            async with persistence.session_factory() as session:
                users = build_user_repository(session)
                new_user = UserCreate(
                    id=str(uuid4()),
                    email="storage-user@example.com",
                    password_hash="hash",
                )
                user = await users.create_user(new_user)
                await session.commit()

            # ... and read back through a fresh one.
            async with persistence.session_factory() as session:
                users = build_user_repository(session)
                assert await users.get_user_by_id(user.id) == user
        finally:
            await persistence.aclose()

    asyncio.run(scenario())
|
||||
@@ -0,0 +1,312 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.repositories import (
|
||||
FeedbackCreate,
|
||||
RunCreate,
|
||||
RunEventCreate,
|
||||
ThreadMetaCreate,
|
||||
build_feedback_repository,
|
||||
build_run_event_repository,
|
||||
build_run_repository,
|
||||
build_thread_meta_repository,
|
||||
)
|
||||
|
||||
|
||||
async def _make_persistence(tmp_path):
    """Create a sqlite-backed persistence object under ``tmp_path`` and run its setup."""
    db_config = SimpleNamespace(
        backend="sqlite",
        sqlite_dir=str(tmp_path),
        echo_sql=False,
        pool_size=5,
    )
    persistence = await create_persistence_from_database_config(db_config)
    await persistence.setup()
    return persistence
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_storage_run_repository_filters_and_aggregates(tmp_path):
    """Runs can be fetched, filtered by thread/user, and aggregated per thread."""
    persistence = await _make_persistence(tmp_path)
    old = datetime.now(UTC) - timedelta(hours=1)
    newer = datetime.now(UTC)

    older_run = RunCreate(
        run_id="run-old",
        thread_id="thread-1",
        user_id="alice",
        status="pending",
        model_name="model-a",
        metadata={"kind": "draft"},
        kwargs={"temperature": 0.2},
        created_time=old,
    )
    newer_run = RunCreate(
        run_id="run-new",
        thread_id="thread-1",
        user_id="bob",
        status="running",
        model_name="model-b",
        error="queued",
        created_time=newer,
    )
    # Lives in another thread and must not show up in thread-1 results.
    unrelated_run = RunCreate(run_id="run-other", thread_id="thread-2", status="running")

    try:
        # Phase 1: write three runs and complete the two thread-1 runs.
        async with persistence.session_factory() as session:
            runs = build_run_repository(session)
            await runs.create_run(older_run)
            await runs.create_run(newer_run)
            await runs.create_run(unrelated_run)
            await runs.update_run_completion(
                "run-old",
                status="success",
                total_input_tokens=7,
                total_output_tokens=3,
                total_tokens=10,
                llm_call_count=1,
                lead_agent_tokens=8,
                subagent_tokens=2,
                first_human_message="hello",
                last_ai_message="world",
            )
            await runs.update_run_completion(
                "run-new",
                status="error",
                total_tokens=5,
                middleware_tokens=5,
                error="failed",
            )
            await session.commit()

        # Phase 2: read everything back through a fresh session.
        async with persistence.session_factory() as session:
            runs = build_run_repository(session)

            fetched = await runs.get_run("run-old")
            assert fetched is not None
            assert fetched.metadata == {"kind": "draft"}
            assert fetched.kwargs == {"temperature": 0.2}
            assert fetched.first_human_message == "hello"
            assert fetched.last_ai_message == "world"

            all_thread_runs = await runs.list_runs_by_thread("thread-1")
            assert [item.run_id for item in all_thread_runs] == ["run-new", "run-old"]
            alice_runs = await runs.list_runs_by_thread("thread-1", user_id="alice")
            assert [item.run_id for item in alice_runs] == ["run-old"]

            # Nothing is expected to be listed as pending.
            pending = await runs.list_pending(before=datetime.now(UTC).isoformat())
            assert [item.run_id for item in pending] == []

            agg = await runs.aggregate_tokens_by_thread("thread-1")
            assert agg["total_tokens"] == 15
            assert agg["total_input_tokens"] == 7
            assert agg["total_output_tokens"] == 3
            assert agg["total_runs"] == 2
            assert agg["by_model"] == {
                "model-a": {"tokens": 10, "runs": 1},
                "model-b": {"tokens": 5, "runs": 1},
            }
            assert agg["by_caller"] == {"lead_agent": 8, "subagent": 2, "middleware": 5}
    finally:
        await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
    """Cover thread-meta creation, update, search filters and deletion."""
    store = await _make_persistence(tmp_path)
    try:
        # Phase 1: create two threads and update the first one.
        async with store.session_factory() as session:
            threads = build_thread_meta_repository(session)

            alice_thread = ThreadMetaCreate(
                thread_id="thread-1",
                assistant_id="agent-a",
                user_id="alice",
                display_name="Initial",
                status="idle",
                metadata={"topic": "finance", "region": "cn"},
            )
            await threads.create_thread_meta(alice_thread)

            bob_thread = ThreadMetaCreate(
                thread_id="thread-2",
                assistant_id="agent-b",
                user_id="bob",
                status="running",
                metadata={"topic": "legal"},
            )
            await threads.create_thread_meta(bob_thread)

            await threads.update_thread_meta(
                "thread-1",
                display_name="Updated",
                status="running",
                metadata={"topic": "finance", "region": "us"},
            )
            await session.commit()

        # Phase 2: verify the update, run the searches, then delete thread-1.
        async with store.session_factory() as session:
            threads = build_thread_meta_repository(session)

            stored = await threads.get_thread_meta("thread-1")
            assert stored is not None
            assert stored.display_name == "Updated"
            assert stored.status == "running"
            assert stored.metadata == {"topic": "finance", "region": "us"}

            finance_hits = await threads.search_threads(metadata={"topic": "finance"}, user_id="alice")
            assert [hit.thread_id for hit in finance_hits] == ["thread-1"]
            agent_b_hits = await threads.search_threads(assistant_id="agent-b")
            assert [hit.thread_id for hit in agent_b_hits] == ["thread-2"]

            await threads.delete_thread("thread-1")
            await session.commit()

        # Phase 3: the deleted thread is no longer retrievable.
        async with store.session_factory() as session:
            threads = build_thread_meta_repository(session)
            assert await threads.get_thread_meta("thread-1") is None
    finally:
        await store.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
    """Cover feedback creation, lookups, listing, deletion and rating validation."""
    store = await _make_persistence(tmp_path)
    try:
        # Phase 1: store one positive and one negative feedback for the same run.
        async with store.session_factory() as session:
            feedback = build_feedback_repository(session)

            positive_payload = FeedbackCreate(
                feedback_id="fb-1",
                run_id="run-1",
                thread_id="thread-1",
                rating=1,
                user_id="alice",
                message_id="msg-1",
                comment="good",
            )
            first = await feedback.create_feedback(positive_payload)

            negative_payload = FeedbackCreate(
                feedback_id="fb-2",
                run_id="run-1",
                thread_id="thread-1",
                rating=-1,
                user_id="bob",
            )
            second = await feedback.create_feedback(negative_payload)
            await session.commit()

        # Phase 2: read back, list, delete, and check that rating=0 is rejected.
        async with store.session_factory() as session:
            feedback = build_feedback_repository(session)

            assert await feedback.get_feedback(first.feedback_id) == first

            run_feedback_ids = [item.feedback_id for item in await feedback.list_feedback_by_run("run-1")]
            assert run_feedback_ids == [second.feedback_id, first.feedback_id]

            thread_feedback_ids = {
                item.feedback_id for item in await feedback.list_feedback_by_thread("thread-1")
            }
            assert thread_feedback_ids == {"fb-1", "fb-2"}

            assert await feedback.delete_feedback("fb-1") is True
            assert await feedback.delete_feedback("missing") is False

            with pytest.raises(ValueError, match="rating must be"):
                await feedback.create_feedback(
                    FeedbackCreate(
                        feedback_id="fb-bad",
                        run_id="run-1",
                        thread_id="thread-1",
                        rating=0,
                    )
                )
            await session.commit()

        # Phase 3: the deletion is visible from a new session.
        async with store.session_factory() as session:
            feedback = build_feedback_repository(session)
            assert await feedback.get_feedback("fb-1") is None
    finally:
        await store.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_path):
    """Cover batch append sequencing, message pagination, event filters and deletes."""
    store = await _make_persistence(tmp_path)
    try:
        # Phase 1: append four events in one batch (three on thread-1, one on thread-2).
        async with store.session_factory() as session:
            events_repo = build_run_event_repository(session)

            batch = [
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-1",
                    user_id="alice",
                    event_type="message",
                    category="message",
                    content={"role": "user", "content": "hello"},
                    metadata={"source": "input"},
                ),
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-1",
                    event_type="tool",
                    category="debug",
                    content="tool-call",
                ),
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-2",
                    event_type="message",
                    category="message",
                    content="second",
                ),
                RunEventCreate(
                    thread_id="thread-2",
                    run_id="run-3",
                    event_type="message",
                    category="message",
                    content="other-thread",
                ),
            ]
            appended = await events_repo.append_batch(batch)
            await session.commit()

            # Sequence numbers are expected to restart at 1 for each thread.
            thread_and_seq = [(row.thread_id, row.seq) for row in appended]
            assert thread_and_seq == [
                ("thread-1", 1),
                ("thread-1", 2),
                ("thread-1", 3),
                ("thread-2", 1),
            ]
            head = appended[0]
            assert head.content == {"role": "user", "content": "hello"}
            assert head.metadata == {"source": "input", "content_is_json": True}

        # Phase 2: query messages and events, then delete by run and by thread.
        async with store.session_factory() as session:
            events_repo = build_run_event_repository(session)

            first_page = await events_repo.list_messages("thread-1", limit=2)
            assert [event.seq for event in first_page] == [1, 3]
            assert await events_repo.count_messages("thread-1") == 2

            run_messages = await events_repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
            assert [event.seq for event in run_messages] == [1]
            earlier_messages = await events_repo.list_messages("thread-1", before_seq=3, limit=5)
            assert [event.seq for event in earlier_messages] == [1]

            tool_events = await events_repo.list_events("thread-1", "run-1", event_types=["tool"])
            assert [event.content for event in tool_events] == ["tool-call"]

            assert await events_repo.delete_by_run("thread-1", "run-1") == 2
            assert await events_repo.delete_by_thread("thread-2") == 1
            await session.commit()

        # Phase 3: only the run-2 event on thread-1 is left.
        async with store.session_factory() as session:
            events_repo = build_run_event_repository(session)
            leftovers = await events_repo.list_events("thread-1", "run-2")
            assert [event.seq for event in leftovers] == [3]
            assert await events_repo.count_messages("thread-2") == 0
    finally:
        await store.aclose()
|
||||
@@ -0,0 +1,178 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from store.persistence import MappedBase
|
||||
from store.repositories import UserCreate, UserNotFoundError, build_user_repository
|
||||
|
||||
|
||||
@asynccontextmanager
async def _session_factory(tmp_path) -> AsyncGenerator[async_sessionmaker[AsyncSession], None]:
    """Yield a session factory bound to a fresh SQLite database under ``tmp_path``.

    All tables registered on ``MappedBase.metadata`` are created before the
    factory is yielded. The engine is disposed when the context exits.
    """
    db_path = tmp_path / "storage-users.db"
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
    try:
        # Schema creation sits inside the ``try`` so the engine is still
        # disposed if ``create_all`` raises.
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

        yield async_sessionmaker(engine, expire_on_commit=False)
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def _create_user(
    session_factory: async_sessionmaker[AsyncSession],
    *,
    email: str = "user@example.com",
    system_role: str = "user",
    oauth_provider: str | None = None,
    oauth_id: str | None = None,
):
    """Insert a user with a random id in its own session, commit, and return it.

    The password hash is always the placeholder ``"hash"``; the remaining
    fields come from the keyword arguments.
    """
    async with session_factory() as session:
        payload = UserCreate(
            id=str(uuid4()),
            email=email,
            password_hash="hash",
            system_role=system_role,  # type: ignore[arg-type]
            oauth_provider=oauth_provider,
            oauth_id=oauth_id,
        )
        users = build_user_repository(session)
        created = await users.create_user(payload)
        await session.commit()
        return created
|
||||
|
||||
|
||||
def test_create_and_get_user_by_id_and_email(tmp_path):
    """A created user can be fetched by id and by email and has the expected defaults."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            created = await _create_user(session_factory)

            async with session_factory() as session:
                users = build_user_repository(session)
                found_by_id = await users.get_user_by_id(created.id)
                found_by_email = await users.get_user_by_email(created.email)

            assert found_by_id == created
            assert found_by_email == created
            assert created.system_role == "user"
            assert created.needs_setup is False
            assert created.token_version == 0

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_duplicate_email_raises_value_error(tmp_path):
    """Creating a second user with an already-used email raises ``ValueError``."""

    async def scenario() -> None:
        taken_email = "dupe@example.com"
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email=taken_email)

            async with session_factory() as session:
                users = build_user_repository(session)
                with pytest.raises(ValueError, match="Email already registered"):
                    duplicate = UserCreate(
                        id=str(uuid4()),
                        email=taken_email,
                        password_hash="hash",
                    )
                    await users.create_user(duplicate)

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_oauth_lookup_and_plain_users_without_oauth(tmp_path):
    """OAuth lookup returns only the matching user; users without OAuth still count."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            # Two users without OAuth identities, created in this order.
            for local_email in ("local-1@example.com", "local-2@example.com"):
                await _create_user(session_factory, email=local_email)
            github_user = await _create_user(
                session_factory,
                email="oauth@example.com",
                oauth_provider="github",
                oauth_id="gh-123",
            )

            async with session_factory() as session:
                users = build_user_repository(session)

                assert await users.count_users() == 3
                assert await users.get_user_by_oauth("github", "gh-123") == github_user
                assert await users.get_user_by_oauth("github", "missing") is None

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_count_admins_and_get_first_admin(tmp_path):
    """Admin count and first-admin lookup ignore users with the ``user`` role."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email="user@example.com")
            admin_user = await _create_user(
                session_factory,
                email="admin@example.com",
                system_role="admin",
            )

            async with session_factory() as session:
                users = build_user_repository(session)

                assert await users.count_users() == 2
                assert await users.count_admin_users() == 1
                assert await users.get_first_admin() == admin_user

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_update_user_round_trips_token_version_and_setup_state(tmp_path):
    """Updated email, token version and setup flag survive a commit and re-read."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            created = await _create_user(session_factory)
            changes = {
                "email": "renamed@example.com",
                "token_version": 4,
                "needs_setup": True,
            }
            updated = created.model_copy(update=changes)

            # Write the update in one session ...
            async with session_factory() as session:
                users = build_user_repository(session)
                saved = await users.update_user(updated)
                await session.commit()

            # ... and read it back from another.
            async with session_factory() as session:
                users = build_user_repository(session)
                reloaded = await users.get_user_by_id(created.id)

            assert saved.email == "renamed@example.com"
            assert reloaded == updated

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_update_missing_user_raises(tmp_path):
    """Updating a user whose insert was rolled back raises ``UserNotFoundError``."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            never_persisted = UserCreate(id=str(uuid4()), email="missing@example.com")

            async with session_factory() as session:
                users = build_user_repository(session)
                # Create only to obtain a user-shaped object, then roll the insert back.
                ghost = await users.create_user(never_persisted)
                await session.rollback()

                with pytest.raises(UserNotFoundError):
                    await users.update_user(ghost)

    asyncio.run(scenario())
|
||||
Generated
+88
@@ -14,6 +14,7 @@ resolution-markers = [
|
||||
members = [
|
||||
"deer-flow",
|
||||
"deerflow-harness",
|
||||
"deerflow-storage",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -136,6 +137,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441, upload-time = "2026-03-31T22:00:12.791Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiomysql"
|
||||
version = "0.3.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pymysql" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/29/e0/302aeffe8d90853556f47f3106b89c16cc2ec2a4d269bdfd82e3f4ae12cc/aiomysql-0.3.2.tar.gz", hash = "sha256:72d15ef5cfc34c03468eb41e1b90adb9fd9347b0b589114bd23ead569a02ac1a", size = 108311, upload-time = "2025-10-22T00:15:21.278Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/af/aae0153c3e28712adaf462328f6c7a3c196a1c1c27b491de4377dd3e6b52/aiomysql-0.3.2-py3-none-any.whl", hash = "sha256:c82c5ba04137d7afd5c693a258bea8ead2aad77101668044143a991e04632eb2", size = 71834, upload-time = "2025-10-22T00:15:15.905Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiosignal"
|
||||
version = "1.4.0"
|
||||
@@ -746,6 +759,7 @@ source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "bcrypt" },
|
||||
{ name = "deerflow-harness" },
|
||||
{ name = "deerflow-storage" },
|
||||
{ name = "dingtalk-stream" },
|
||||
{ name = "email-validator" },
|
||||
{ name = "fastapi" },
|
||||
@@ -765,6 +779,7 @@ dependencies = [
|
||||
[package.optional-dependencies]
|
||||
postgres = [
|
||||
{ name = "deerflow-harness", extra = ["postgres"] },
|
||||
{ name = "deerflow-storage", extra = ["postgres"] },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
@@ -780,6 +795,8 @@ requires-dist = [
|
||||
{ name = "bcrypt", specifier = ">=4.0.0" },
|
||||
{ name = "deerflow-harness", editable = "packages/harness" },
|
||||
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
|
||||
{ name = "deerflow-storage", editable = "packages/storage" },
|
||||
{ name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
|
||||
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
|
||||
{ name = "email-validator", specifier = ">=2.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.115.0" },
|
||||
@@ -901,6 +918,54 @@ requires-dist = [
|
||||
]
|
||||
provides-extras = ["ollama", "postgres", "pymupdf"]
|
||||
|
||||
[[package]]
|
||||
name = "deerflow-storage"
|
||||
version = "0.1.0"
|
||||
source = { editable = "packages/storage" }
|
||||
dependencies = [
|
||||
{ name = "alembic" },
|
||||
{ name = "dotenv" },
|
||||
{ name = "langgraph" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "sqlalchemy", extra = ["asyncio"] },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
mysql = [
|
||||
{ name = "aiomysql" },
|
||||
{ name = "langgraph-checkpoint-mysql" },
|
||||
]
|
||||
postgres = [
|
||||
{ name = "asyncpg" },
|
||||
{ name = "langgraph-checkpoint-postgres" },
|
||||
{ name = "psycopg", extra = ["binary"] },
|
||||
{ name = "psycopg-pool" },
|
||||
]
|
||||
sqlite = [
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "langgraph-checkpoint-sqlite" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiomysql", marker = "extra == 'mysql'", specifier = ">=0.2" },
|
||||
{ name = "aiosqlite", marker = "extra == 'sqlite'", specifier = ">=0.22.1" },
|
||||
{ name = "alembic", specifier = ">=1.13" },
|
||||
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
|
||||
{ name = "dotenv", specifier = ">=0.9.9" },
|
||||
{ name = "langgraph", specifier = ">=1.1.9" },
|
||||
{ name = "langgraph-checkpoint-mysql", marker = "extra == 'mysql'", specifier = ">=3.0.0" },
|
||||
{ name = "langgraph-checkpoint-postgres", marker = "extra == 'postgres'", specifier = ">=3.0.5" },
|
||||
{ name = "langgraph-checkpoint-sqlite", marker = "extra == 'sqlite'", specifier = ">=3.0.3" },
|
||||
{ name = "psycopg", extras = ["binary"], marker = "extra == 'postgres'", specifier = ">=3.3.3" },
|
||||
{ name = "psycopg-pool", marker = "extra == 'postgres'", specifier = ">=3.3.0" },
|
||||
{ name = "pydantic", specifier = ">=2.12.5" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.3" },
|
||||
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0,<3.0" },
|
||||
]
|
||||
provides-extras = ["postgres", "mysql", "sqlite"]
|
||||
|
||||
[[package]]
|
||||
name = "defusedxml"
|
||||
version = "0.7.1"
|
||||
@@ -1914,6 +1979,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/5a/6dba29dd89b0a46ae21c707da0f9d17e94f27d3e481ed15bc99d6bd20aa6/langgraph_checkpoint-4.0.2-py3-none-any.whl", hash = "sha256:59b0f29216128a629c58dd07c98aa004f82f51805d5573126ffb419b753ff253", size = 51000, upload-time = "2026-04-15T21:02:59.096Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langgraph-checkpoint-mysql"
|
||||
version = "3.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "langgraph-checkpoint" },
|
||||
{ name = "orjson" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e0/4e/0a6c78e5d3f2ca1525903c2363e721873594b6b77dd83537a6369193c474/langgraph_checkpoint_mysql-3.0.0.tar.gz", hash = "sha256:006aaa089f4c2fbd7b2c113b800ccd3dbb95f92203e656451677256b4b4f880f", size = 213142, upload-time = "2026-01-23T11:11:15.74Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/08/68/343103a7fae05523f9cecabbec2babdb737e66b4bf6ea48ae00c685ed11c/langgraph_checkpoint_mysql-3.0.0-py3-none-any.whl", hash = "sha256:7560ccd16e7596a047e15a307cec12dbd88fdcaab45a75759e5c6adef22a27d1", size = 38009, upload-time = "2026-01-23T11:11:14.697Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langgraph-checkpoint-postgres"
|
||||
version = "3.0.5"
|
||||
@@ -3442,6 +3521,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/38/84bf29f4dd72e6c450546df6ca8f53021f764fd945ba67dcc235d39bc20e/pymupdf4llm-1.27.2.3-py3-none-any.whl", hash = "sha256:bd724b79fa3f06a5b28d7a65f7acfa8de56e04bdb603ac2d6dff315e0d151aaa", size = 77348, upload-time = "2026-04-24T14:11:04.305Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pymysql"
|
||||
version = "1.1.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7f/ec/8d45c920e90445f0b75c590b32851853ed319763b0d8dff8d283052da8cf/pymysql-1.1.3.tar.gz", hash = "sha256:e70ebf2047a4edf6138cf79c68ad418ef620af65900aa585c5e8bfc95044d43a", size = 48207, upload-time = "2026-05-01T09:09:54.532Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/dc/9085f3d6f497e9b25fb40d6e8ecef3ddbb5cf977a949b933624a299f5c16/pymysql-1.1.3-py3-none-any.whl", hash = "sha256:8164ba62c552f6105f3b11753352d0f16b90d1703ba67d81923d5a8a5d1c5289", size = 45356, upload-time = "2026-05-01T09:09:53.316Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pypdfium2"
|
||||
version = "5.7.1"
|
||||
|
||||
Reference in New Issue
Block a user