fix(storage): address code quality comments
This commit is contained in:
@@ -39,9 +39,15 @@ class FeedbackAggregate(TypedDict):
|
||||
|
||||
|
||||
class FeedbackRepositoryProtocol(Protocol):
|
||||
async def create_feedback(self, data: FeedbackCreate) -> Feedback: ...
|
||||
async def upsert_feedback(self, data: FeedbackCreate) -> Feedback: ...
|
||||
async def get_feedback(self, feedback_id: str) -> Feedback | None: ...
|
||||
async def create_feedback(self, data: FeedbackCreate) -> Feedback:
|
||||
pass
|
||||
|
||||
async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
|
||||
pass
|
||||
|
||||
async def get_feedback(self, feedback_id: str) -> Feedback | None:
|
||||
pass
|
||||
|
||||
async def list_feedback_by_run(
|
||||
self,
|
||||
run_id: str,
|
||||
@@ -49,14 +55,23 @@ class FeedbackRepositoryProtocol(Protocol):
|
||||
thread_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[Feedback]: ...
|
||||
) -> list[Feedback]:
|
||||
pass
|
||||
|
||||
async def list_feedback_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[Feedback]: ...
|
||||
async def delete_feedback(self, feedback_id: str) -> bool: ...
|
||||
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool: ...
|
||||
async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate: ...
|
||||
) -> list[Feedback]:
|
||||
pass
|
||||
|
||||
async def delete_feedback(self, feedback_id: str) -> bool:
|
||||
pass
|
||||
|
||||
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
|
||||
pass
|
||||
|
||||
async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
|
||||
pass
|
||||
|
||||
@@ -52,8 +52,12 @@ class Run(BaseModel):
|
||||
|
||||
|
||||
class RunRepositoryProtocol(Protocol):
|
||||
async def create_run(self, data: RunCreate) -> Run: ...
|
||||
async def get_run(self, run_id: str) -> Run | None: ...
|
||||
async def create_run(self, data: RunCreate) -> Run:
|
||||
pass
|
||||
|
||||
async def get_run(self, run_id: str) -> Run | None:
|
||||
pass
|
||||
|
||||
async def list_runs_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
@@ -61,10 +65,18 @@ class RunRepositoryProtocol(Protocol):
|
||||
user_id: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[Run]: ...
|
||||
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None: ...
|
||||
async def delete_run(self, run_id: str) -> None: ...
|
||||
async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]: ...
|
||||
) -> list[Run]:
|
||||
pass
|
||||
|
||||
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
|
||||
pass
|
||||
|
||||
async def delete_run(self, run_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
|
||||
pass
|
||||
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
@@ -81,5 +93,8 @@ class RunRepositoryProtocol(Protocol):
|
||||
first_human_message: str | None = None,
|
||||
last_ai_message: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None: ...
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: ...
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
pass
|
||||
|
||||
@@ -36,7 +36,8 @@ class RunEvent(BaseModel):
|
||||
class RunEventRepositoryProtocol(Protocol):
|
||||
# Sequence values are time-ordered integer cursors. The application layer
|
||||
# owns the single-writer invariant for a thread while a run is active.
|
||||
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ...
|
||||
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
|
||||
pass
|
||||
|
||||
async def list_messages(
|
||||
self,
|
||||
@@ -46,7 +47,8 @@ class RunEventRepositoryProtocol(Protocol):
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]: ...
|
||||
) -> list[RunEvent]:
|
||||
pass
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
@@ -56,7 +58,8 @@ class RunEventRepositoryProtocol(Protocol):
|
||||
event_types: list[str] | None = None,
|
||||
limit: int = 500,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]: ...
|
||||
) -> list[RunEvent]:
|
||||
pass
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
@@ -67,10 +70,14 @@ class RunEventRepositoryProtocol(Protocol):
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]: ...
|
||||
) -> list[RunEvent]:
|
||||
pass
|
||||
|
||||
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int: ...
|
||||
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
|
||||
pass
|
||||
|
||||
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int: ...
|
||||
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
|
||||
pass
|
||||
|
||||
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int: ...
|
||||
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
|
||||
pass
|
||||
|
||||
@@ -35,9 +35,11 @@ class ThreadMeta(BaseModel):
|
||||
|
||||
|
||||
class ThreadMetaRepositoryProtocol(Protocol):
|
||||
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta: ...
|
||||
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
|
||||
pass
|
||||
|
||||
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None: ...
|
||||
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
|
||||
pass
|
||||
|
||||
async def update_thread_meta(
|
||||
self,
|
||||
@@ -46,9 +48,11 @@ class ThreadMetaRepositoryProtocol(Protocol):
|
||||
display_name: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def delete_thread(self, thread_id: str) -> None: ...
|
||||
async def delete_thread(self, thread_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def search_threads(
|
||||
self,
|
||||
@@ -59,4 +63,5 @@ class ThreadMetaRepositoryProtocol(Protocol):
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ThreadMeta]: ...
|
||||
) -> list[ThreadMeta]:
|
||||
pass
|
||||
|
||||
@@ -39,18 +39,26 @@ class User(BaseModel):
|
||||
|
||||
|
||||
class UserRepositoryProtocol(Protocol):
|
||||
async def create_user(self, data: UserCreate) -> User: ...
|
||||
async def create_user(self, data: UserCreate) -> User:
|
||||
pass
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None: ...
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
pass
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None: ...
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
pass
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None: ...
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
pass
|
||||
|
||||
async def get_first_admin(self) -> User | None: ...
|
||||
async def get_first_admin(self) -> User | None:
|
||||
pass
|
||||
|
||||
async def update_user(self, data: User) -> User: ...
|
||||
async def update_user(self, data: User) -> User:
|
||||
pass
|
||||
|
||||
async def count_users(self) -> int: ...
|
||||
async def count_users(self) -> int:
|
||||
pass
|
||||
|
||||
async def count_admin_users(self) -> int: ...
|
||||
async def count_admin_users(self) -> int:
|
||||
pass
|
||||
|
||||
@@ -17,20 +17,25 @@ _SEQ_PROCESS_BITS = 9
|
||||
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
|
||||
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
|
||||
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
|
||||
_last_seq_millis = 0
|
||||
_seq_lock = threading.Lock()
|
||||
|
||||
|
||||
def _allocate_sequence_base(batch_size: int) -> int:
|
||||
if batch_size >= _SEQ_COUNTER_LIMIT:
|
||||
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
|
||||
class _SequenceAllocator:
|
||||
def __init__(self) -> None:
|
||||
self._last_millis = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
global _last_seq_millis
|
||||
now_ms = time.time_ns() // 1_000_000
|
||||
with _seq_lock:
|
||||
seq_ms = max(now_ms, _last_seq_millis + 1)
|
||||
_last_seq_millis = seq_ms
|
||||
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
|
||||
def allocate_base(self, batch_size: int) -> int:
|
||||
if batch_size >= _SEQ_COUNTER_LIMIT:
|
||||
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
|
||||
|
||||
now_ms = time.time_ns() // 1_000_000
|
||||
with self._lock:
|
||||
seq_ms = max(now_ms, self._last_millis + 1)
|
||||
self._last_millis = seq_ms
|
||||
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
|
||||
|
||||
|
||||
_sequence_allocator = _SequenceAllocator()
|
||||
|
||||
|
||||
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||
@@ -75,7 +80,7 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
if not events:
|
||||
return []
|
||||
|
||||
seq_base = _allocate_sequence_base(len(events))
|
||||
seq_base = _sequence_allocator.allocate_base(len(events))
|
||||
|
||||
rows: list[RunEventModel] = []
|
||||
|
||||
|
||||
Reference in New Issue
Block a user