fix(storage): address repository review feedback

This commit is contained in:
rayhpeng
2026-05-13 12:51:45 +08:00
parent d3066a1746
commit 11a9041b65
13 changed files with 140 additions and 65 deletions
@@ -34,6 +34,8 @@ class RunEvent(BaseModel):
class RunEventRepositoryProtocol(Protocol):
# Sequence values are time-ordered integer cursors. The application layer
# owns the single-writer invariant for a thread while a run is active.
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ...
async def list_messages(
@@ -1,6 +1,9 @@
from __future__ import annotations
import json
import secrets
import threading
import time
from typing import Any
from sqlalchemy import delete, func, select
@@ -9,6 +12,26 @@ from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
from store.repositories.models.run_event import RunEvent as RunEventModel
# Bit layout of a sequence base, from least to most significant bits:
#   [counter: _SEQ_COUNTER_BITS] [process salt: _SEQ_PROCESS_BITS] [millisecond tick]
_SEQ_COUNTER_BITS = 12
_SEQ_PROCESS_BITS = 9
# Random value drawn once when this module is imported.
# NOTE(review): only 2 ** _SEQ_PROCESS_BITS distinct salts exist, so two
# processes can draw the same one — confirm that is acceptable.
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
# Exclusive upper bound on the batch size accepted by _allocate_sequence_base.
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
# Number of low bits (counter + salt) that sit below the millisecond tick.
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
# Most recent millisecond tick handed out by _allocate_sequence_base;
# it is updated while holding _seq_lock.
_last_seq_millis = 0
_seq_lock = threading.Lock()
def _allocate_sequence_base(batch_size: int) -> int:
    """Reserve a sequence base for one batch of run events.

    The result packs a millisecond tick into the high bits and the
    per-process salt directly beneath it. The lowest ``_SEQ_COUNTER_BITS``
    bits are left as zero so a per-event offset can be added on top.

    Args:
        batch_size: Number of events the base has to accommodate.

    Returns:
        The integer base for this batch.

    Raises:
        ValueError: If ``batch_size`` is ``_SEQ_COUNTER_LIMIT`` or larger.
    """
    global _last_seq_millis

    if batch_size >= _SEQ_COUNTER_LIMIT:
        raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")

    wall_clock_ms = time.time_ns() // 1_000_000
    with _seq_lock:
        # Each call gets a tick strictly greater than the previous one, even
        # when the wall clock has not advanced (or has moved backwards).
        if wall_clock_ms > _last_seq_millis:
            tick = wall_clock_ms
        else:
            tick = _last_seq_millis + 1
        _last_seq_millis = tick

    salt_field = _SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS
    return (tick << _SEQ_TIMESTAMP_SHIFT) | salt_field
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
if not isinstance(content, str):
@@ -52,28 +75,17 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
if not events:
return []
thread_ids = {event.thread_id for event in events}
seq_by_thread: dict[str, int] = {}
for thread_id in thread_ids:
max_seq = await self._session.scalar(
select(RunEventModel.seq)
.where(RunEventModel.thread_id == thread_id)
.order_by(RunEventModel.seq.desc())
.limit(1)
.with_for_update()
)
seq_by_thread[thread_id] = max_seq or 0
seq_base = _allocate_sequence_base(len(events))
rows: list[RunEventModel] = []
for event in events:
seq_by_thread[event.thread_id] += 1
for index, event in enumerate(events, start=1):
content, metadata = _serialize_content(event.content, dict(event.metadata))
row = RunEventModel(
thread_id=event.thread_id,
run_id=event.run_id,
user_id=event.user_id,
seq=seq_by_thread[event.thread_id],
seq=seq_base + index,
event_type=event.event_type,
category=event.category,
content=content,
@@ -5,10 +5,7 @@ from datetime import datetime
from sqlalchemy import Integer, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
from store.utils import get_timezone
_tz = get_timezone()
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
class Feedback(DataClassBase):
@@ -33,7 +30,7 @@ class Feedback(DataClassBase):
"created_at",
TimeZone,
init=False,
default_factory=_tz.now,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -6,10 +6,7 @@ from typing import Any
from sqlalchemy import JSON, Index, Integer, String
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
from store.utils import get_timezone
_tz = get_timezone()
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
class Run(DataClassBase):
@@ -51,7 +48,7 @@ class Run(DataClassBase):
"created_at",
TimeZone,
init=False,
default_factory=_tz.now,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -60,7 +57,7 @@ class Run(DataClassBase):
TimeZone,
init=False,
default=None,
onupdate=_tz.now,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -3,13 +3,16 @@ from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, Index, Integer, String, UniqueConstraint
from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key
from store.utils import get_timezone
_tz = get_timezone()
from store.persistence.base_model import (
DataClassBase,
TimeZone,
UniversalText,
current_time,
id_key,
)
class RunEvent(DataClassBase):
@@ -31,13 +34,13 @@ class RunEvent(DataClassBase):
category: Mapped[str] = mapped_column(String(16), index=True)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
seq: Mapped[int] = mapped_column(Integer, default=0, index=True)
seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
content: Mapped[str] = mapped_column(UniversalText, default="")
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
created_at: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=_tz.now,
default_factory=current_time,
sort_order=999,
comment="Event timestamp",
)
@@ -6,10 +6,7 @@ from typing import Any
from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone
from store.utils import get_timezone
_tz = get_timezone()
from store.persistence.base_model import DataClassBase, TimeZone, current_time
class ThreadMeta(DataClassBase):
@@ -31,7 +28,7 @@ class ThreadMeta(DataClassBase):
"created_at",
TimeZone,
init=False,
default_factory=_tz.now,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -40,7 +37,7 @@ class ThreadMeta(DataClassBase):
TimeZone,
init=False,
default=None,
onupdate=_tz.now,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -5,10 +5,7 @@ from datetime import datetime
from sqlalchemy import Boolean, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone
from store.utils import get_timezone
_tz = get_timezone()
from store.persistence.base_model import DataClassBase, TimeZone, current_time
class User(DataClassBase):
@@ -39,7 +36,7 @@ class User(DataClassBase):
created_at: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=_tz.now,
default_factory=current_time,
sort_order=999,
comment="Created at",
)