diff --git a/backend/packages/storage/store/repositories/contracts/feedback.py b/backend/packages/storage/store/repositories/contracts/feedback.py index 0ad0486ea..6f4b77c24 100644 --- a/backend/packages/storage/store/repositories/contracts/feedback.py +++ b/backend/packages/storage/store/repositories/contracts/feedback.py @@ -39,9 +39,15 @@ class FeedbackAggregate(TypedDict): class FeedbackRepositoryProtocol(Protocol): - async def create_feedback(self, data: FeedbackCreate) -> Feedback: ... - async def upsert_feedback(self, data: FeedbackCreate) -> Feedback: ... - async def get_feedback(self, feedback_id: str) -> Feedback | None: ... + async def create_feedback(self, data: FeedbackCreate) -> Feedback: + pass + + async def upsert_feedback(self, data: FeedbackCreate) -> Feedback: + pass + + async def get_feedback(self, feedback_id: str) -> Feedback | None: + pass + async def list_feedback_by_run( self, run_id: str, @@ -49,14 +55,23 @@ class FeedbackRepositoryProtocol(Protocol): thread_id: str | None = None, user_id: str | None = None, limit: int | None = None, - ) -> list[Feedback]: ... + ) -> list[Feedback]: + pass + async def list_feedback_by_thread( self, thread_id: str, *, user_id: str | None = None, limit: int | None = None, - ) -> list[Feedback]: ... - async def delete_feedback(self, feedback_id: str) -> bool: ... - async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool: ... - async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate: ... + ) -> list[Feedback]: + pass + + async def delete_feedback(self, feedback_id: str) -> bool: + pass + + async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool: + pass + + async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate: + pass diff --git a/backend/packages/storage/store/repositories/contracts/run.py b/backend/packages/storage/store/repositories/contracts/run.py index a14995dc9..448e1536b 100644 --- a/backend/packages/storage/store/repositories/contracts/run.py +++ b/backend/packages/storage/store/repositories/contracts/run.py @@ -52,8 +52,12 @@ class Run(BaseModel): class RunRepositoryProtocol(Protocol): - async def create_run(self, data: RunCreate) -> Run: ... - async def get_run(self, run_id: str) -> Run | None: ... + async def create_run(self, data: RunCreate) -> Run: + pass + + async def get_run(self, run_id: str) -> Run | None: + pass + async def list_runs_by_thread( self, thread_id: str, @@ -61,10 +65,18 @@ class RunRepositoryProtocol(Protocol): user_id: str | None = None, limit: int = 50, offset: int = 0, - ) -> list[Run]: ... - async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None: ... - async def delete_run(self, run_id: str) -> None: ... - async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]: ... + ) -> list[Run]: + pass + + async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None: + pass + + async def delete_run(self, run_id: str) -> None: + pass + + async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]: + pass + async def update_run_completion( self, run_id: str, @@ -81,5 +93,8 @@ class RunRepositoryProtocol(Protocol): first_human_message: str | None = None, last_ai_message: str | None = None, error: str | None = None, - ) -> None: ... - async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: ... + ) -> None: + pass + + async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: + pass diff --git a/backend/packages/storage/store/repositories/contracts/run_event.py b/backend/packages/storage/store/repositories/contracts/run_event.py index d0cb11aa3..195d11f25 100644 --- a/backend/packages/storage/store/repositories/contracts/run_event.py +++ b/backend/packages/storage/store/repositories/contracts/run_event.py @@ -36,7 +36,8 @@ class RunEvent(BaseModel): class RunEventRepositoryProtocol(Protocol): # Sequence values are time-ordered integer cursors. The application layer # owns the single-writer invariant for a thread while a run is active. - async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ... + async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: + pass async def list_messages( self, @@ -46,7 +47,8 @@ class RunEventRepositoryProtocol(Protocol): before_seq: int | None = None, after_seq: int | None = None, user_id: str | None = None, - ) -> list[RunEvent]: ... + ) -> list[RunEvent]: + pass async def list_events( self, @@ -56,7 +58,8 @@ class RunEventRepositoryProtocol(Protocol): event_types: list[str] | None = None, limit: int = 500, user_id: str | None = None, - ) -> list[RunEvent]: ... + ) -> list[RunEvent]: + pass async def list_messages_by_run( self, @@ -67,10 +70,14 @@ class RunEventRepositoryProtocol(Protocol): before_seq: int | None = None, after_seq: int | None = None, user_id: str | None = None, - ) -> list[RunEvent]: ... + ) -> list[RunEvent]: + pass - async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int: ... + async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int: + pass - async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int: ... + async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int: + pass - async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int: ... + async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int: + pass diff --git a/backend/packages/storage/store/repositories/contracts/thread_meta.py b/backend/packages/storage/store/repositories/contracts/thread_meta.py index 8222ca0e4..228fc2447 100644 --- a/backend/packages/storage/store/repositories/contracts/thread_meta.py +++ b/backend/packages/storage/store/repositories/contracts/thread_meta.py @@ -35,9 +35,11 @@ class ThreadMeta(BaseModel): class ThreadMetaRepositoryProtocol(Protocol): - async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta: ... + async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta: + pass - async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None: ... + async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None: + pass async def update_thread_meta( self, @@ -46,9 +48,11 @@ class ThreadMetaRepositoryProtocol(Protocol): display_name: str | None = None, status: str | None = None, metadata: dict[str, Any] | None = None, - ) -> None: ... + ) -> None: + pass - async def delete_thread(self, thread_id: str) -> None: ... + async def delete_thread(self, thread_id: str) -> None: + pass async def search_threads( self, @@ -59,4 +63,5 @@ class ThreadMetaRepositoryProtocol(Protocol): assistant_id: str | None = None, limit: int = 100, offset: int = 0, - ) -> list[ThreadMeta]: ... + ) -> list[ThreadMeta]: + pass diff --git a/backend/packages/storage/store/repositories/contracts/user.py b/backend/packages/storage/store/repositories/contracts/user.py index 828ecf0fe..7fb678c0b 100644 --- a/backend/packages/storage/store/repositories/contracts/user.py +++ b/backend/packages/storage/store/repositories/contracts/user.py @@ -39,18 +39,26 @@ class User(BaseModel): class UserRepositoryProtocol(Protocol): - async def create_user(self, data: UserCreate) -> User: ... + async def create_user(self, data: UserCreate) -> User: + pass - async def get_user_by_id(self, user_id: str) -> User | None: ... + async def get_user_by_id(self, user_id: str) -> User | None: + pass - async def get_user_by_email(self, email: str) -> User | None: ... + async def get_user_by_email(self, email: str) -> User | None: + pass - async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: ... + async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: + pass - async def get_first_admin(self) -> User | None: ... + async def get_first_admin(self) -> User | None: + pass - async def update_user(self, data: User) -> User: ... + async def update_user(self, data: User) -> User: + pass - async def count_users(self) -> int: ... + async def count_users(self) -> int: + pass - async def count_admin_users(self) -> int: ... + async def count_admin_users(self) -> int: + pass diff --git a/backend/packages/storage/store/repositories/db/run_event.py b/backend/packages/storage/store/repositories/db/run_event.py index df5c005e6..06ea8d65b 100644 --- a/backend/packages/storage/store/repositories/db/run_event.py +++ b/backend/packages/storage/store/repositories/db/run_event.py @@ -17,20 +17,25 @@ _SEQ_PROCESS_BITS = 9 _SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS) _SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS _SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS -_last_seq_millis = 0 -_seq_lock = threading.Lock() -def _allocate_sequence_base(batch_size: int) -> int: - if batch_size >= _SEQ_COUNTER_LIMIT: - raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}") +class _SequenceAllocator: + def __init__(self) -> None: + self._last_millis = 0 + self._lock = threading.Lock() - global _last_seq_millis - now_ms = time.time_ns() // 1_000_000 - with _seq_lock: - seq_ms = max(now_ms, _last_seq_millis + 1) - _last_seq_millis = seq_ms - return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS) + def allocate_base(self, batch_size: int) -> int: + if batch_size >= _SEQ_COUNTER_LIMIT: + raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}") + + now_ms = time.time_ns() // 1_000_000 + with self._lock: + seq_ms = max(now_ms, self._last_millis + 1) + self._last_millis = seq_ms + return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS) + + +_sequence_allocator = _SequenceAllocator() def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]: @@ -75,7 +80,7 @@ class DbRunEventRepository(RunEventRepositoryProtocol): if not events: return [] - seq_base = _allocate_sequence_base(len(events)) + seq_base = _sequence_allocator.allocate_base(len(events)) rows: list[RunEventModel] = []