fix(storage): address repository review feedback
This commit is contained in:
@@ -6,20 +6,17 @@ from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
|
||||
|
||||
from store.common import DataBaseType
|
||||
from store.config.app_config import get_app_config
|
||||
from store.utils import get_timezone
|
||||
|
||||
timezone = get_timezone()
|
||||
app_config = get_app_config()
|
||||
|
||||
# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT)
|
||||
_id_type = Integer if app_config.storage.driver == DataBaseType.sqlite else BigInteger
|
||||
def current_time() -> datetime:
|
||||
return get_timezone().now()
|
||||
|
||||
|
||||
id_key = Annotated[
|
||||
int,
|
||||
mapped_column(
|
||||
_id_type,
|
||||
BigInteger().with_variant(Integer, "sqlite"),
|
||||
primary_key=True,
|
||||
unique=True,
|
||||
index=True,
|
||||
@@ -33,9 +30,14 @@ id_key = Annotated[
|
||||
class UniversalText(TypeDecorator[str]):
|
||||
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
|
||||
|
||||
impl = LONGTEXT if DataBaseType.mysql == app_config.storage.driver else Text
|
||||
impl = Text
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect): # noqa: ANN001
|
||||
if dialect.name == "mysql":
|
||||
return dialect.type_descriptor(LONGTEXT())
|
||||
return dialect.type_descriptor(Text())
|
||||
|
||||
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
|
||||
return value
|
||||
|
||||
@@ -54,11 +56,13 @@ class TimeZone(TypeDecorator[datetime]):
|
||||
return datetime
|
||||
|
||||
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
||||
timezone = get_timezone()
|
||||
if value is not None and value.utcoffset() != timezone.now().utcoffset():
|
||||
value = timezone.from_datetime(value)
|
||||
return value
|
||||
|
||||
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
||||
timezone = get_timezone()
|
||||
if value is not None and value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.tz_info)
|
||||
return value
|
||||
@@ -70,14 +74,14 @@ class DateTimeMixin(MappedAsDataclass):
|
||||
created_time: Mapped[datetime] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=timezone.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
updated_time: Mapped[datetime | None] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
onupdate=timezone.now,
|
||||
onupdate=current_time,
|
||||
sort_order=999,
|
||||
comment="Updated at",
|
||||
)
|
||||
|
||||
@@ -70,6 +70,8 @@ def _create_database_url(storage_config: StorageConfig) -> URL:
|
||||
url = make_url(storage_config.database_url)
|
||||
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
|
||||
url = url.set(drivername="postgresql+asyncpg")
|
||||
elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
|
||||
url = url.set(drivername="mysql+aiomysql")
|
||||
else:
|
||||
url = URL.create(
|
||||
drivername=driver,
|
||||
|
||||
@@ -34,6 +34,8 @@ class RunEvent(BaseModel):
|
||||
|
||||
|
||||
class RunEventRepositoryProtocol(Protocol):
|
||||
# Sequence values are time-ordered integer cursors. The application layer
|
||||
# owns the single-writer invariant for a thread while a run is active.
|
||||
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ...
|
||||
|
||||
async def list_messages(
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
@@ -9,6 +12,26 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
|
||||
from store.repositories.models.run_event import RunEvent as RunEventModel
|
||||
|
||||
_SEQ_COUNTER_BITS = 12
|
||||
_SEQ_PROCESS_BITS = 9
|
||||
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
|
||||
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
|
||||
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
|
||||
_last_seq_millis = 0
|
||||
_seq_lock = threading.Lock()
|
||||
|
||||
|
||||
def _allocate_sequence_base(batch_size: int) -> int:
|
||||
if batch_size >= _SEQ_COUNTER_LIMIT:
|
||||
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
|
||||
|
||||
global _last_seq_millis
|
||||
now_ms = time.time_ns() // 1_000_000
|
||||
with _seq_lock:
|
||||
seq_ms = max(now_ms, _last_seq_millis + 1)
|
||||
_last_seq_millis = seq_ms
|
||||
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
|
||||
|
||||
|
||||
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||
if not isinstance(content, str):
|
||||
@@ -52,28 +75,17 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
if not events:
|
||||
return []
|
||||
|
||||
thread_ids = {event.thread_id for event in events}
|
||||
seq_by_thread: dict[str, int] = {}
|
||||
for thread_id in thread_ids:
|
||||
max_seq = await self._session.scalar(
|
||||
select(RunEventModel.seq)
|
||||
.where(RunEventModel.thread_id == thread_id)
|
||||
.order_by(RunEventModel.seq.desc())
|
||||
.limit(1)
|
||||
.with_for_update()
|
||||
)
|
||||
seq_by_thread[thread_id] = max_seq or 0
|
||||
seq_base = _allocate_sequence_base(len(events))
|
||||
|
||||
rows: list[RunEventModel] = []
|
||||
|
||||
for event in events:
|
||||
seq_by_thread[event.thread_id] += 1
|
||||
for index, event in enumerate(events, start=1):
|
||||
content, metadata = _serialize_content(event.content, dict(event.metadata))
|
||||
row = RunEventModel(
|
||||
thread_id=event.thread_id,
|
||||
run_id=event.run_id,
|
||||
user_id=event.user_id,
|
||||
seq=seq_by_thread[event.thread_id],
|
||||
seq=seq_base + index,
|
||||
event_type=event.event_type,
|
||||
category=event.category,
|
||||
content=content,
|
||||
|
||||
@@ -5,10 +5,7 @@ from datetime import datetime
|
||||
from sqlalchemy import Integer, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
|
||||
|
||||
|
||||
class Feedback(DataClassBase):
|
||||
@@ -33,7 +30,7 @@ class Feedback(DataClassBase):
|
||||
"created_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=_tz.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
|
||||
@@ -6,10 +6,7 @@ from typing import Any
|
||||
from sqlalchemy import JSON, Index, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
|
||||
|
||||
|
||||
class Run(DataClassBase):
|
||||
@@ -51,7 +48,7 @@ class Run(DataClassBase):
|
||||
"created_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=_tz.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
@@ -60,7 +57,7 @@ class Run(DataClassBase):
|
||||
TimeZone,
|
||||
init=False,
|
||||
default=None,
|
||||
onupdate=_tz.now,
|
||||
onupdate=current_time,
|
||||
sort_order=999,
|
||||
comment="Updated at",
|
||||
)
|
||||
|
||||
@@ -3,13 +3,16 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, Index, Integer, String, UniqueConstraint
|
||||
from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
from store.persistence.base_model import (
|
||||
DataClassBase,
|
||||
TimeZone,
|
||||
UniversalText,
|
||||
current_time,
|
||||
id_key,
|
||||
)
|
||||
|
||||
|
||||
class RunEvent(DataClassBase):
|
||||
@@ -31,13 +34,13 @@ class RunEvent(DataClassBase):
|
||||
category: Mapped[str] = mapped_column(String(16), index=True)
|
||||
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
|
||||
seq: Mapped[int] = mapped_column(Integer, default=0, index=True)
|
||||
seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
|
||||
content: Mapped[str] = mapped_column(UniversalText, default="")
|
||||
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=_tz.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Event timestamp",
|
||||
)
|
||||
|
||||
@@ -6,10 +6,7 @@ from typing import Any
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, current_time
|
||||
|
||||
|
||||
class ThreadMeta(DataClassBase):
|
||||
@@ -31,7 +28,7 @@ class ThreadMeta(DataClassBase):
|
||||
"created_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=_tz.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
@@ -40,7 +37,7 @@ class ThreadMeta(DataClassBase):
|
||||
TimeZone,
|
||||
init=False,
|
||||
default=None,
|
||||
onupdate=_tz.now,
|
||||
onupdate=current_time,
|
||||
sort_order=999,
|
||||
comment="Updated at",
|
||||
)
|
||||
|
||||
@@ -5,10 +5,7 @@ from datetime import datetime
|
||||
from sqlalchemy import Boolean, Index, String, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone
|
||||
from store.utils import get_timezone
|
||||
|
||||
_tz = get_timezone()
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, current_time
|
||||
|
||||
|
||||
class User(DataClassBase):
|
||||
@@ -39,7 +36,7 @@ class User(DataClassBase):
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=_tz.now,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ dependencies = [
|
||||
|
||||
[project.optional-dependencies]
|
||||
postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
|
||||
mysql = ["deerflow-storage[mysql]"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
@@ -62,6 +64,30 @@ def test_database_postgres_config_preserves_url_and_pool_options():
|
||||
assert url.database == "deerflow"
|
||||
|
||||
|
||||
def test_mysql_database_url_is_normalized_to_async_driver():
|
||||
storage = StorageConfig(
|
||||
driver="mysql",
|
||||
database_url="mysql://user:pass@db.example:3306/deerflow",
|
||||
)
|
||||
|
||||
url = _create_database_url(storage)
|
||||
|
||||
assert url.drivername == "mysql+aiomysql"
|
||||
assert url.database == "deerflow"
|
||||
|
||||
|
||||
def test_mysql_async_database_url_is_preserved():
|
||||
storage = StorageConfig(
|
||||
driver="mysql",
|
||||
database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow",
|
||||
)
|
||||
|
||||
url = _create_database_url(storage)
|
||||
|
||||
assert url.drivername == "mysql+asyncmy"
|
||||
assert url.database == "deerflow"
|
||||
|
||||
|
||||
def test_database_postgres_requires_url():
|
||||
database = SimpleNamespace(backend="postgres", postgres_url="")
|
||||
|
||||
@@ -74,3 +100,25 @@ def test_unsupported_database_backend_rejected():
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported database backend"):
|
||||
storage_config_from_database_config(database)
|
||||
|
||||
|
||||
def test_storage_models_import_without_config_file(tmp_path):
|
||||
env = os.environ.copy()
|
||||
env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml")
|
||||
|
||||
result = subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
"-c",
|
||||
"from store.persistence.base_model import UniversalText, id_key; "
|
||||
"from store.repositories.models import RunEvent; "
|
||||
"print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
|
||||
],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
env=env,
|
||||
text=True,
|
||||
)
|
||||
|
||||
assert result.returncode == 0, result.stderr
|
||||
assert "UniversalText run_events" in result.stdout
|
||||
|
||||
@@ -348,25 +348,23 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
assert [(row.thread_id, row.seq) for row in rows] == [
|
||||
("thread-1", 1),
|
||||
("thread-1", 2),
|
||||
("thread-1", 3),
|
||||
("thread-2", 1),
|
||||
]
|
||||
assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
|
||||
assert [row.seq for row in rows] == sorted(row.seq for row in rows)
|
||||
assert rows[1].seq == rows[0].seq + 1
|
||||
assert rows[2].seq == rows[1].seq + 1
|
||||
assert rows[0].content == {"role": "user", "content": "hello"}
|
||||
assert rows[0].metadata == {"source": "input", "content_is_json": True}
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
messages = await repo.list_messages("thread-1", limit=2)
|
||||
assert [event.seq for event in messages] == [1, 3]
|
||||
assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
|
||||
assert await repo.count_messages("thread-1") == 2
|
||||
|
||||
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
|
||||
assert [event.seq for event in after] == [1]
|
||||
before = await repo.list_messages("thread-1", before_seq=3, limit=5)
|
||||
assert [event.seq for event in before] == [1]
|
||||
assert [event.seq for event in after] == [rows[0].seq]
|
||||
before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
|
||||
assert [event.seq for event in before] == [rows[0].seq]
|
||||
|
||||
events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
|
||||
assert [event.content for event in events] == ["tool-call"]
|
||||
@@ -378,7 +376,20 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
remaining = await repo.list_events("thread-1", "run-2")
|
||||
assert [event.seq for event in remaining] == [3]
|
||||
assert [event.seq for event in remaining] == [rows[2].seq]
|
||||
assert await repo.count_messages("thread-2") == 0
|
||||
|
||||
later = await repo.append_batch(
|
||||
[
|
||||
RunEventCreate(
|
||||
thread_id="thread-1",
|
||||
run_id="run-4",
|
||||
event_type="message",
|
||||
category="message",
|
||||
content="after-delete",
|
||||
)
|
||||
]
|
||||
)
|
||||
assert later[0].seq > rows[2].seq
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
Generated
+5
-1
@@ -777,6 +777,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
mysql = [
|
||||
{ name = "deerflow-storage", extra = ["mysql"] },
|
||||
]
|
||||
postgres = [
|
||||
{ name = "deerflow-harness", extra = ["postgres"] },
|
||||
{ name = "deerflow-storage", extra = ["postgres"] },
|
||||
@@ -796,6 +799,7 @@ requires-dist = [
|
||||
{ name = "deerflow-harness", editable = "packages/harness" },
|
||||
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
|
||||
{ name = "deerflow-storage", editable = "packages/storage" },
|
||||
{ name = "deerflow-storage", extras = ["mysql"], marker = "extra == 'mysql'", editable = "packages/storage" },
|
||||
{ name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
|
||||
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
|
||||
{ name = "email-validator", specifier = ">=2.0.0" },
|
||||
@@ -812,7 +816,7 @@ requires-dist = [
|
||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
|
||||
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
|
||||
]
|
||||
provides-extras = ["postgres"]
|
||||
provides-extras = ["postgres", "mysql"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
|
||||
Reference in New Issue
Block a user