fix(storage): harden sql persistence compatibility

This commit is contained in:
rayhpeng
2026-05-13 11:26:25 +08:00
parent 485f8a2bf2
commit d3066a1746
11 changed files with 398 additions and 9 deletions
@@ -153,9 +153,10 @@ class DbRunRepository(RunRepositoryProtocol):
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
completed = RunModel.status.in_(("success", "error"))
model_expr = func.coalesce(RunModel.model_name, "unknown")
stmt = (
select(
func.coalesce(RunModel.model_name, "unknown").label("model"),
model_expr.label("model"),
func.count().label("runs"),
func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
@@ -165,7 +166,7 @@ class DbRunRepository(RunRepositoryProtocol):
func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
)
.where(RunModel.thread_id == thread_id, completed)
.group_by(func.coalesce(RunModel.model_name, "unknown"))
.group_by(model_expr)
)
rows = (await self._session.execute(stmt)).all()
@@ -56,8 +56,10 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
seq_by_thread: dict[str, int] = {}
for thread_id in thread_ids:
max_seq = await self._session.scalar(
select(func.max(RunEventModel.seq))
select(RunEventModel.seq)
.where(RunEventModel.thread_id == thread_id)
.order_by(RunEventModel.seq.desc())
.limit(1)
.with_for_update()
)
seq_by_thread[thread_id] = max_seq or 0
@@ -1,13 +1,22 @@
from __future__ import annotations
import logging
from typing import Any
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.thread_meta import ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol
from store.persistence.json_compat import json_match
from store.repositories.contracts.thread_meta import (
InvalidMetadataFilterError,
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
)
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
logger = logging.getLogger(__name__)
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
return ThreadMeta(
@@ -87,10 +96,18 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
if assistant_id is not None:
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
if metadata:
applied = 0
for key, value in metadata.items():
stmt = stmt.where(ThreadMetaModel.meta[key].as_string() == str(value))
try:
stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
applied += 1
except (ValueError, TypeError) as exc:
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
if applied == 0:
rejected_keys = ", ".join(sorted(str(key) for key in metadata))
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
stmt = stmt.order_by(ThreadMetaModel.created_time.desc())
stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
stmt = stmt.limit(limit).offset(offset)
result = await self._session.execute(stmt)