fix(storage): address repository review feedback

This commit is contained in:
rayhpeng
2026-05-13 12:51:45 +08:00
parent d3066a1746
commit 11a9041b65
13 changed files with 140 additions and 65 deletions
@@ -6,20 +6,17 @@ from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
from store.common import DataBaseType
from store.config.app_config import get_app_config
from store.utils import get_timezone
timezone = get_timezone()
app_config = get_app_config()
# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT)
_id_type = Integer if app_config.storage.driver == DataBaseType.sqlite else BigInteger
def current_time() -> datetime:
return get_timezone().now()
id_key = Annotated[
int,
mapped_column(
_id_type,
BigInteger().with_variant(Integer, "sqlite"),
primary_key=True,
unique=True,
index=True,
@@ -33,9 +30,14 @@ id_key = Annotated[
class UniversalText(TypeDecorator[str]):
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
impl = LONGTEXT if DataBaseType.mysql == app_config.storage.driver else Text
impl = Text
cache_ok = True
def load_dialect_impl(self, dialect): # noqa: ANN001
if dialect.name == "mysql":
return dialect.type_descriptor(LONGTEXT())
return dialect.type_descriptor(Text())
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
return value
@@ -54,11 +56,13 @@ class TimeZone(TypeDecorator[datetime]):
return datetime
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
timezone = get_timezone()
if value is not None and value.utcoffset() != timezone.now().utcoffset():
value = timezone.from_datetime(value)
return value
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
timezone = get_timezone()
if value is not None and value.tzinfo is None:
value = value.replace(tzinfo=timezone.tz_info)
return value
@@ -70,14 +74,14 @@ class DateTimeMixin(MappedAsDataclass):
created_time: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=timezone.now,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
TimeZone,
init=False,
onupdate=timezone.now,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -70,6 +70,8 @@ def _create_database_url(storage_config: StorageConfig) -> URL:
url = make_url(storage_config.database_url)
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
url = url.set(drivername="postgresql+asyncpg")
elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
url = url.set(drivername="mysql+aiomysql")
else:
url = URL.create(
drivername=driver,