fix(storage): address repository review feedback

This commit is contained in:
rayhpeng
2026-05-13 12:51:45 +08:00
parent d3066a1746
commit 11a9041b65
13 changed files with 140 additions and 65 deletions
@@ -6,20 +6,17 @@ from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
from store.common import DataBaseType
from store.config.app_config import get_app_config
from store.utils import get_timezone
timezone = get_timezone()
app_config = get_app_config()
# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT)
_id_type = Integer if app_config.storage.driver == DataBaseType.sqlite else BigInteger
def current_time() -> datetime:
    """Return the current time as reported by the configured timezone helper.

    The helper is resolved through ``get_timezone()`` on every call rather
    than being cached by this function.
    """
    tz = get_timezone()
    return tz.now()
id_key = Annotated[
int,
mapped_column(
_id_type,
BigInteger().with_variant(Integer, "sqlite"),
primary_key=True,
unique=True,
index=True,
@@ -33,9 +30,14 @@ id_key = Annotated[
class UniversalText(TypeDecorator[str]):
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on every other dialect)."""
impl = LONGTEXT if DataBaseType.mysql == app_config.storage.driver else Text
impl = Text
cache_ok = True
def load_dialect_impl(self, dialect):  # noqa: ANN001
    """Pick the concrete column type for the dialect in use.

    MySQL (``dialect.name == "mysql"``) gets ``LONGTEXT``; every other
    dialect gets ``Text``. The chosen type is passed through
    ``dialect.type_descriptor`` and the result is returned.
    """
    chosen_type = LONGTEXT() if dialect.name == "mysql" else Text()
    return dialect.type_descriptor(chosen_type)
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
return value
@@ -54,11 +56,13 @@ class TimeZone(TypeDecorator[datetime]):
return datetime
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None:  # noqa: ANN001
    """Normalise an outgoing datetime before it is bound to a statement.

    ``None`` passes through. A value whose UTC offset equals the offset of
    the timezone helper's current time is returned unchanged. Any other
    value is converted with the helper's ``from_datetime``. A naive
    datetime reports an offset of ``None``, so it is converted whenever the
    helper's own offset is not ``None``.
    """
    # Resolve the helper first, exactly as before, even when value is None.
    tz = get_timezone()
    if value is None:
        return None
    if value.utcoffset() == tz.now().utcoffset():
        return value
    return tz.from_datetime(value)
def process_result_value(self, value: datetime | None, dialect) -> datetime | None:  # noqa: ANN001
    """Attach the configured timezone to naive datetimes loaded from a result row.

    ``None`` and datetimes that already carry ``tzinfo`` are returned as-is;
    a naive datetime gets the helper's ``tz_info`` set via ``replace``
    (the clock fields are not shifted).
    """
    # Resolve the helper first, exactly as before, even when value is None.
    tz = get_timezone()
    if value is None or value.tzinfo is not None:
        return value
    return value.replace(tzinfo=tz.tz_info)
@@ -70,14 +74,14 @@ class DateTimeMixin(MappedAsDataclass):
created_time: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=timezone.now,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
TimeZone,
init=False,
onupdate=timezone.now,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -70,6 +70,8 @@ def _create_database_url(storage_config: StorageConfig) -> URL:
url = make_url(storage_config.database_url)
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
url = url.set(drivername="postgresql+asyncpg")
elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
url = url.set(drivername="mysql+aiomysql")
else:
url = URL.create(
drivername=driver,
@@ -34,6 +34,8 @@ class RunEvent(BaseModel):
class RunEventRepositoryProtocol(Protocol):
# Sequence values are time-ordered integer cursors. The application layer
# owns the single-writer invariant for a thread while a run is active.
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ...
async def list_messages(
@@ -1,6 +1,9 @@
from __future__ import annotations
import json
import secrets
import threading
import time
from typing import Any
from sqlalchemy import delete, func, select
@@ -9,6 +12,26 @@ from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
from store.repositories.models.run_event import RunEvent as RunEventModel
# Bit layout of a run-event sequence base, most to least significant:
#   [ millisecond timestamp | process salt | in-batch counter ]

# Low field: the 1-based position of an event inside one batch is added here,
# so a batch must stay below this limit.
_SEQ_COUNTER_BITS = 12
_SEQ_COUNTER_LIMIT = 2**_SEQ_COUNTER_BITS

# Middle field: a random value drawn once when this module is imported.
_SEQ_PROCESS_BITS = 9
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)

# The timestamp is shifted above both lower fields.
_SEQ_TIMESTAMP_SHIFT = _SEQ_PROCESS_BITS + _SEQ_COUNTER_BITS

# The lock guards the last millisecond value handed out by the allocator.
_seq_lock = threading.Lock()
_last_seq_millis = 0
def _allocate_sequence_base(batch_size: int) -> int:
    """Reserve a sequence base for one batch of run events.

    The result packs a millisecond timestamp above the process salt and
    leaves the low ``_SEQ_COUNTER_BITS`` bits zero. Each call takes a
    millisecond strictly greater than the one handed out by the previous
    call in this process, even if the wall clock has not advanced or has
    moved backwards.

    Args:
        batch_size: Number of events that will be numbered from this base.

    Returns:
        The integer base with the timestamp and salt fields filled in.

    Raises:
        ValueError: If ``batch_size`` is not below ``_SEQ_COUNTER_LIMIT``.
    """
    global _last_seq_millis

    if _SEQ_COUNTER_LIMIT <= batch_size:
        raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")

    wall_clock_ms = time.time_ns() // 1_000_000
    with _seq_lock:
        # Never reuse or go below a millisecond that was already handed out.
        earliest_free_ms = _last_seq_millis + 1
        allocated_ms = wall_clock_ms if wall_clock_ms >= earliest_free_ms else earliest_free_ms
        _last_seq_millis = allocated_ms

    timestamp_field = allocated_ms << _SEQ_TIMESTAMP_SHIFT
    salt_field = _SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS
    return timestamp_field | salt_field
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
if not isinstance(content, str):
@@ -52,28 +75,17 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
if not events:
return []
thread_ids = {event.thread_id for event in events}
seq_by_thread: dict[str, int] = {}
for thread_id in thread_ids:
max_seq = await self._session.scalar(
select(RunEventModel.seq)
.where(RunEventModel.thread_id == thread_id)
.order_by(RunEventModel.seq.desc())
.limit(1)
.with_for_update()
)
seq_by_thread[thread_id] = max_seq or 0
seq_base = _allocate_sequence_base(len(events))
rows: list[RunEventModel] = []
for event in events:
seq_by_thread[event.thread_id] += 1
for index, event in enumerate(events, start=1):
content, metadata = _serialize_content(event.content, dict(event.metadata))
row = RunEventModel(
thread_id=event.thread_id,
run_id=event.run_id,
user_id=event.user_id,
seq=seq_by_thread[event.thread_id],
seq=seq_base + index,
event_type=event.event_type,
category=event.category,
content=content,
@@ -5,10 +5,7 @@ from datetime import datetime
from sqlalchemy import Integer, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
from store.utils import get_timezone
_tz = get_timezone()
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
class Feedback(DataClassBase):
@@ -33,7 +30,7 @@ class Feedback(DataClassBase):
"created_at",
TimeZone,
init=False,
default_factory=_tz.now,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -6,10 +6,7 @@ from typing import Any
from sqlalchemy import JSON, Index, Integer, String
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
from store.utils import get_timezone
_tz = get_timezone()
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
class Run(DataClassBase):
@@ -51,7 +48,7 @@ class Run(DataClassBase):
"created_at",
TimeZone,
init=False,
default_factory=_tz.now,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -60,7 +57,7 @@ class Run(DataClassBase):
TimeZone,
init=False,
default=None,
onupdate=_tz.now,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -3,13 +3,16 @@ from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, Index, Integer, String, UniqueConstraint
from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key
from store.utils import get_timezone
_tz = get_timezone()
from store.persistence.base_model import (
DataClassBase,
TimeZone,
UniversalText,
current_time,
id_key,
)
class RunEvent(DataClassBase):
@@ -31,13 +34,13 @@ class RunEvent(DataClassBase):
category: Mapped[str] = mapped_column(String(16), index=True)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
seq: Mapped[int] = mapped_column(Integer, default=0, index=True)
seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
content: Mapped[str] = mapped_column(UniversalText, default="")
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
created_at: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=_tz.now,
default_factory=current_time,
sort_order=999,
comment="Event timestamp",
)
@@ -6,10 +6,7 @@ from typing import Any
from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone
from store.utils import get_timezone
_tz = get_timezone()
from store.persistence.base_model import DataClassBase, TimeZone, current_time
class ThreadMeta(DataClassBase):
@@ -31,7 +28,7 @@ class ThreadMeta(DataClassBase):
"created_at",
TimeZone,
init=False,
default_factory=_tz.now,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -40,7 +37,7 @@ class ThreadMeta(DataClassBase):
TimeZone,
init=False,
default=None,
onupdate=_tz.now,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -5,10 +5,7 @@ from datetime import datetime
from sqlalchemy import Boolean, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone
from store.utils import get_timezone
_tz = get_timezone()
from store.persistence.base_model import DataClassBase, TimeZone, current_time
class User(DataClassBase):
@@ -39,7 +36,7 @@ class User(DataClassBase):
created_at: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=_tz.now,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
+1
View File
@@ -26,6 +26,7 @@ dependencies = [
[project.optional-dependencies]
postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
mysql = ["deerflow-storage[mysql]"]
[dependency-groups]
dev = [
@@ -1,6 +1,8 @@
from __future__ import annotations
import os
import subprocess
import sys
from pathlib import Path
from types import SimpleNamespace
@@ -62,6 +64,30 @@ def test_database_postgres_config_preserves_url_and_pool_options():
assert url.database == "deerflow"
def test_mysql_database_url_is_normalized_to_async_driver():
    """A bare ``mysql://`` URL is rewritten to use the aiomysql driver."""
    dsn = "mysql://user:pass@db.example:3306/deerflow"
    config = StorageConfig(driver="mysql", database_url=dsn)

    resolved = _create_database_url(config)

    assert resolved.drivername == "mysql+aiomysql"
    assert resolved.database == "deerflow"
def test_mysql_async_database_url_is_preserved():
    """A URL that already names an async MySQL driver keeps that driver."""
    dsn = "mysql+asyncmy://user:pass@db.example:3306/deerflow"
    config = StorageConfig(driver="mysql", database_url=dsn)

    resolved = _create_database_url(config)

    assert resolved.drivername == "mysql+asyncmy"
    assert resolved.database == "deerflow"
def test_database_postgres_requires_url():
database = SimpleNamespace(backend="postgres", postgres_url="")
@@ -74,3 +100,25 @@ def test_unsupported_database_backend_rejected():
with pytest.raises(ValueError, match="Unsupported database backend"):
storage_config_from_database_config(database)
def test_storage_models_import_without_config_file(tmp_path):
    """Storage models must import even when the config file is absent.

    The imports run in a fresh interpreter whose ``DEER_FLOW_CONFIG_PATH``
    points at ``missing-config.yaml`` under ``tmp_path``, a file this test
    never creates.
    """
    probe_script = (
        "from store.persistence.base_model import UniversalText, id_key; "
        "from store.repositories.models import RunEvent; "
        "print(UniversalText.__name__, RunEvent.__tablename__, id_key)"
    )
    child_env = {**os.environ, "DEER_FLOW_CONFIG_PATH": str(tmp_path / "missing-config.yaml")}

    completed = subprocess.run(
        [sys.executable, "-c", probe_script],
        env=child_env,
        capture_output=True,
        text=True,
        check=False,
    )

    # Surface the child's stderr when the import fails.
    assert completed.returncode == 0, completed.stderr
    assert "UniversalText run_events" in completed.stdout
+22 -11
View File
@@ -348,25 +348,23 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
)
await session.commit()
assert [(row.thread_id, row.seq) for row in rows] == [
("thread-1", 1),
("thread-1", 2),
("thread-1", 3),
("thread-2", 1),
]
assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
assert [row.seq for row in rows] == sorted(row.seq for row in rows)
assert rows[1].seq == rows[0].seq + 1
assert rows[2].seq == rows[1].seq + 1
assert rows[0].content == {"role": "user", "content": "hello"}
assert rows[0].metadata == {"source": "input", "content_is_json": True}
async with persistence.session_factory() as session:
repo = build_run_event_repository(session)
messages = await repo.list_messages("thread-1", limit=2)
assert [event.seq for event in messages] == [1, 3]
assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
assert await repo.count_messages("thread-1") == 2
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
assert [event.seq for event in after] == [1]
before = await repo.list_messages("thread-1", before_seq=3, limit=5)
assert [event.seq for event in before] == [1]
assert [event.seq for event in after] == [rows[0].seq]
before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
assert [event.seq for event in before] == [rows[0].seq]
events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
assert [event.content for event in events] == ["tool-call"]
@@ -378,7 +376,20 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
async with persistence.session_factory() as session:
repo = build_run_event_repository(session)
remaining = await repo.list_events("thread-1", "run-2")
assert [event.seq for event in remaining] == [3]
assert [event.seq for event in remaining] == [rows[2].seq]
assert await repo.count_messages("thread-2") == 0
later = await repo.append_batch(
[
RunEventCreate(
thread_id="thread-1",
run_id="run-4",
event_type="message",
category="message",
content="after-delete",
)
]
)
assert later[0].seq > rows[2].seq
finally:
await persistence.aclose()
+5 -1
View File
@@ -777,6 +777,9 @@ dependencies = [
]
[package.optional-dependencies]
mysql = [
{ name = "deerflow-storage", extra = ["mysql"] },
]
postgres = [
{ name = "deerflow-harness", extra = ["postgres"] },
{ name = "deerflow-storage", extra = ["postgres"] },
@@ -796,6 +799,7 @@ requires-dist = [
{ name = "deerflow-harness", editable = "packages/harness" },
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
{ name = "deerflow-storage", editable = "packages/storage" },
{ name = "deerflow-storage", extras = ["mysql"], marker = "extra == 'mysql'", editable = "packages/storage" },
{ name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
{ name = "email-validator", specifier = ">=2.0.0" },
@@ -812,7 +816,7 @@ requires-dist = [
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
]
provides-extras = ["postgres"]
provides-extras = ["postgres", "mysql"]
[package.metadata.requires-dev]
dev = [