fix(storage): address repository review feedback

This commit is contained in:
rayhpeng
2026-05-13 12:51:45 +08:00
parent d3066a1746
commit 11a9041b65
13 changed files with 140 additions and 65 deletions
@@ -1,6 +1,9 @@
from __future__ import annotations
import json
import secrets
import threading
import time
from typing import Any
from sqlalchemy import delete, func, select
@@ -9,6 +12,26 @@ from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
from store.repositories.models.run_event import RunEvent as RunEventModel
_SEQ_COUNTER_BITS = 12
_SEQ_PROCESS_BITS = 9
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
_last_seq_millis = 0
_seq_lock = threading.Lock()
def _allocate_sequence_base(batch_size: int) -> int:
if batch_size >= _SEQ_COUNTER_LIMIT:
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
global _last_seq_millis
now_ms = time.time_ns() // 1_000_000
with _seq_lock:
seq_ms = max(now_ms, _last_seq_millis + 1)
_last_seq_millis = seq_ms
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
if not isinstance(content, str):
@@ -52,28 +75,17 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
if not events:
return []
thread_ids = {event.thread_id for event in events}
seq_by_thread: dict[str, int] = {}
for thread_id in thread_ids:
max_seq = await self._session.scalar(
select(RunEventModel.seq)
.where(RunEventModel.thread_id == thread_id)
.order_by(RunEventModel.seq.desc())
.limit(1)
.with_for_update()
)
seq_by_thread[thread_id] = max_seq or 0
seq_base = _allocate_sequence_base(len(events))
rows: list[RunEventModel] = []
for event in events:
seq_by_thread[event.thread_id] += 1
for index, event in enumerate(events, start=1):
content, metadata = _serialize_content(event.content, dict(event.metadata))
row = RunEventModel(
thread_id=event.thread_id,
run_id=event.run_id,
user_id=event.user_id,
seq=seq_by_thread[event.thread_id],
seq=seq_base + index,
event_type=event.event_type,
category=event.category,
content=content,