mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
fix(storage): address repository review feedback
This commit is contained in:
@@ -6,20 +6,17 @@ from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
|
||||
|
||||
from store.common import DataBaseType
|
||||
from store.config.app_config import get_app_config
|
||||
from store.utils import get_timezone
|
||||
|
||||
timezone = get_timezone()
|
||||
app_config = get_app_config()
|
||||
|
||||
# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT)
|
||||
_id_type = Integer if app_config.storage.driver == DataBaseType.sqlite else BigInteger
|
||||
def current_time() -> datetime:
|
||||
return get_timezone().now()
|
||||
|
||||
|
||||
id_key = Annotated[
|
||||
int,
|
||||
mapped_column(
|
||||
_id_type,
|
||||
BigInteger().with_variant(Integer, "sqlite"),
|
||||
primary_key=True,
|
||||
unique=True,
|
||||
index=True,
|
||||
@@ -33,9 +30,14 @@ id_key = Annotated[
|
||||
class UniversalText(TypeDecorator[str]):
|
||||
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
|
||||
|
||||
impl = LONGTEXT if DataBaseType.mysql == app_config.storage.driver else Text
|
||||
impl = Text
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect): # noqa: ANN001
|
||||
if dialect.name == "mysql":
|
||||
return dialect.type_descriptor(LONGTEXT())
|
||||
return dialect.type_descriptor(Text())
|
||||
|
||||
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
|
||||
return value
|
||||
|
||||
@@ -54,11 +56,13 @@ class TimeZone(TypeDecorator[datetime]):
|
||||
return datetime
|
||||
|
||||
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
||||
timezone = get_timezone()
|
||||
if value is not None and value.utcoffset() != timezone.now().utcoffset():
|
||||
value = timezone.from_datetime(value)
|
||||
return value
|
||||
|
||||
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
||||
timezone = get_timezone()
|
||||
if value is not None and value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.tz_info)
|
||||
return value
|
||||
@@ -70,14 +74,14 @@ class DateTimeMixin(MappedAsDataclass):
|
||||
created_time: Mapped[datetime] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=timezone.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
updated_time: Mapped[datetime | None] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
onupdate=timezone.now,
|
||||
onupdate=current_time,
|
||||
sort_order=999,
|
||||
comment="Updated at",
|
||||
)
|
||||
|
||||
@@ -70,6 +70,8 @@ def _create_database_url(storage_config: StorageConfig) -> URL:
|
||||
url = make_url(storage_config.database_url)
|
||||
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
|
||||
url = url.set(drivername="postgresql+asyncpg")
|
||||
elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
|
||||
url = url.set(drivername="mysql+aiomysql")
|
||||
else:
|
||||
url = URL.create(
|
||||
drivername=driver,
|
||||
|
||||
@@ -34,6 +34,8 @@ class RunEvent(BaseModel):
|
||||
|
||||
|
||||
class RunEventRepositoryProtocol(Protocol):
|
||||
# Sequence values are time-ordered integer cursors. The application layer
|
||||
# owns the single-writer invariant for a thread while a run is active.
|
||||
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ...
|
||||
|
||||
async def list_messages(
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
@@ -9,6 +12,26 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
|
||||
from store.repositories.models.run_event import RunEvent as RunEventModel
|
||||
|
||||
_SEQ_COUNTER_BITS = 12
|
||||
_SEQ_PROCESS_BITS = 9
|
||||
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
|
||||
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
|
||||
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
|
||||
_last_seq_millis = 0
|
||||
_seq_lock = threading.Lock()
|
||||
|
||||
|
||||
def _allocate_sequence_base(batch_size: int) -> int:
|
||||
if batch_size >= _SEQ_COUNTER_LIMIT:
|
||||
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
|
||||
|
||||
global _last_seq_millis
|
||||
now_ms = time.time_ns() // 1_000_000
|
||||
with _seq_lock:
|
||||
seq_ms = max(now_ms, _last_seq_millis + 1)
|
||||
_last_seq_millis = seq_ms
|
||||
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
|
||||
|
||||
|
||||
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||
if not isinstance(content, str):
|
||||
@@ -52,28 +75,17 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
if not events:
|
||||
return []
|
||||
|
||||
thread_ids = {event.thread_id for event in events}
|
||||
seq_by_thread: dict[str, int] = {}
|
||||
for thread_id in thread_ids:
|
||||
max_seq = await self._session.scalar(
|
||||
select(RunEventModel.seq)
|
||||
.where(RunEventModel.thread_id == thread_id)
|
||||
.order_by(RunEventModel.seq.desc())
|
||||
.limit(1)
|
||||
.with_for_update()
|
||||
)
|
||||
seq_by_thread[thread_id] = max_seq or 0
|
||||
seq_base = _allocate_sequence_base(len(events))
|
||||
|
||||
rows: list[RunEventModel] = []
|
||||
|
||||
for event in events:
|
||||
seq_by_thread[event.thread_id] += 1
|
||||
for index, event in enumerate(events, start=1):
|
||||
content, metadata = _serialize_content(event.content, dict(event.metadata))
|
||||
row = RunEventModel(
|
||||
thread_id=event.thread_id,
|
||||
run_id=event.run_id,
|
||||
user_id=event.user_id,
|
||||
seq=seq_by_thread[event.thread_id],
|
||||
seq=seq_base + index,
|
||||
event_type=event.event_type,
|
||||
category=event.category,
|
||||
content=content,
|
||||
|
||||
@@ -5,10 +5,7 @@ from datetime import datetime
|
||||
from sqlalchemy import Integer, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
|
||||
|
||||
|
||||
class Feedback(DataClassBase):
|
||||
@@ -33,7 +30,7 @@ class Feedback(DataClassBase):
|
||||
"created_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=_tz.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
|
||||
@@ -6,10 +6,7 @@ from typing import Any
|
||||
from sqlalchemy import JSON, Index, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
|
||||
|
||||
|
||||
class Run(DataClassBase):
|
||||
@@ -51,7 +48,7 @@ class Run(DataClassBase):
|
||||
"created_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=_tz.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
@@ -60,7 +57,7 @@ class Run(DataClassBase):
|
||||
TimeZone,
|
||||
init=False,
|
||||
default=None,
|
||||
onupdate=_tz.now,
|
||||
onupdate=current_time,
|
||||
sort_order=999,
|
||||
comment="Updated at",
|
||||
)
|
||||
|
||||
@@ -3,13 +3,16 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, Index, Integer, String, UniqueConstraint
|
||||
from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
from store.persistence.base_model import (
|
||||
DataClassBase,
|
||||
TimeZone,
|
||||
UniversalText,
|
||||
current_time,
|
||||
id_key,
|
||||
)
|
||||
|
||||
|
||||
class RunEvent(DataClassBase):
|
||||
@@ -31,13 +34,13 @@ class RunEvent(DataClassBase):
|
||||
category: Mapped[str] = mapped_column(String(16), index=True)
|
||||
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
|
||||
seq: Mapped[int] = mapped_column(Integer, default=0, index=True)
|
||||
seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
|
||||
content: Mapped[str] = mapped_column(UniversalText, default="")
|
||||
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=_tz.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Event timestamp",
|
||||
)
|
||||
|
||||
@@ -6,10 +6,7 @@ from typing import Any
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, current_time
|
||||
|
||||
|
||||
class ThreadMeta(DataClassBase):
|
||||
@@ -31,7 +28,7 @@ class ThreadMeta(DataClassBase):
|
||||
"created_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=_tz.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
@@ -40,7 +37,7 @@ class ThreadMeta(DataClassBase):
|
||||
TimeZone,
|
||||
init=False,
|
||||
default=None,
|
||||
onupdate=_tz.now,
|
||||
onupdate=current_time,
|
||||
sort_order=999,
|
||||
comment="Updated at",
|
||||
)
|
||||
|
||||
@@ -5,10 +5,7 @@ from datetime import datetime
|
||||
from sqlalchemy import Boolean, Index, String, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, current_time
|
||||
|
||||
|
||||
class User(DataClassBase):
|
||||
@@ -39,7 +36,7 @@ class User(DataClassBase):
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=_tz.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user