mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
fix(storage): harden sql persistence compatibility
This commit is contained in:
@@ -153,9 +153,10 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
completed = RunModel.status.in_(("success", "error"))
|
||||
model_expr = func.coalesce(RunModel.model_name, "unknown")
|
||||
stmt = (
|
||||
select(
|
||||
func.coalesce(RunModel.model_name, "unknown").label("model"),
|
||||
model_expr.label("model"),
|
||||
func.count().label("runs"),
|
||||
func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
|
||||
func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
|
||||
@@ -165,7 +166,7 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
|
||||
)
|
||||
.where(RunModel.thread_id == thread_id, completed)
|
||||
.group_by(func.coalesce(RunModel.model_name, "unknown"))
|
||||
.group_by(model_expr)
|
||||
)
|
||||
|
||||
rows = (await self._session.execute(stmt)).all()
|
||||
|
||||
@@ -56,8 +56,10 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
seq_by_thread: dict[str, int] = {}
|
||||
for thread_id in thread_ids:
|
||||
max_seq = await self._session.scalar(
|
||||
select(func.max(RunEventModel.seq))
|
||||
select(RunEventModel.seq)
|
||||
.where(RunEventModel.thread_id == thread_id)
|
||||
.order_by(RunEventModel.seq.desc())
|
||||
.limit(1)
|
||||
.with_for_update()
|
||||
)
|
||||
seq_by_thread[thread_id] = max_seq or 0
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.thread_meta import ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol
|
||||
from store.persistence.json_compat import json_match
|
||||
from store.repositories.contracts.thread_meta import (
|
||||
InvalidMetadataFilterError,
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
|
||||
return ThreadMeta(
|
||||
@@ -87,10 +96,18 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
|
||||
if assistant_id is not None:
|
||||
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
|
||||
if metadata:
|
||||
applied = 0
|
||||
for key, value in metadata.items():
|
||||
stmt = stmt.where(ThreadMetaModel.meta[key].as_string() == str(value))
|
||||
try:
|
||||
stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
|
||||
applied += 1
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
|
||||
if applied == 0:
|
||||
rejected_keys = ", ".join(sorted(str(key) for key in metadata))
|
||||
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
|
||||
|
||||
stmt = stmt.order_by(ThreadMetaModel.created_time.desc())
|
||||
stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
|
||||
Reference in New Issue
Block a user