fix(storage): address repository review feedback

This commit is contained in:
rayhpeng
2026-05-13 12:51:45 +08:00
parent d3066a1746
commit 11a9041b65
13 changed files with 140 additions and 65 deletions
@@ -6,20 +6,17 @@ from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
from store.common import DataBaseType
from store.config.app_config import get_app_config
from store.utils import get_timezone from store.utils import get_timezone
timezone = get_timezone()
app_config = get_app_config()
# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT) def current_time() -> datetime:
_id_type = Integer if app_config.storage.driver == DataBaseType.sqlite else BigInteger return get_timezone().now()
id_key = Annotated[ id_key = Annotated[
int, int,
mapped_column( mapped_column(
_id_type, BigInteger().with_variant(Integer, "sqlite"),
primary_key=True, primary_key=True,
unique=True, unique=True,
index=True, index=True,
@@ -33,9 +30,14 @@ id_key = Annotated[
class UniversalText(TypeDecorator[str]): class UniversalText(TypeDecorator[str]):
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL).""" """Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
impl = LONGTEXT if DataBaseType.mysql == app_config.storage.driver else Text impl = Text
cache_ok = True cache_ok = True
def load_dialect_impl(self, dialect): # noqa: ANN001
if dialect.name == "mysql":
return dialect.type_descriptor(LONGTEXT())
return dialect.type_descriptor(Text())
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001 def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
return value return value
@@ -54,11 +56,13 @@ class TimeZone(TypeDecorator[datetime]):
return datetime return datetime
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001 def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
timezone = get_timezone()
if value is not None and value.utcoffset() != timezone.now().utcoffset(): if value is not None and value.utcoffset() != timezone.now().utcoffset():
value = timezone.from_datetime(value) value = timezone.from_datetime(value)
return value return value
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001 def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
timezone = get_timezone()
if value is not None and value.tzinfo is None: if value is not None and value.tzinfo is None:
value = value.replace(tzinfo=timezone.tz_info) value = value.replace(tzinfo=timezone.tz_info)
return value return value
@@ -70,14 +74,14 @@ class DateTimeMixin(MappedAsDataclass):
created_time: Mapped[datetime] = mapped_column( created_time: Mapped[datetime] = mapped_column(
TimeZone, TimeZone,
init=False, init=False,
default_factory=timezone.now, default_factory=current_time,
sort_order=999, sort_order=999,
comment="Created at", comment="Created at",
) )
updated_time: Mapped[datetime | None] = mapped_column( updated_time: Mapped[datetime | None] = mapped_column(
TimeZone, TimeZone,
init=False, init=False,
onupdate=timezone.now, onupdate=current_time,
sort_order=999, sort_order=999,
comment="Updated at", comment="Updated at",
) )
@@ -70,6 +70,8 @@ def _create_database_url(storage_config: StorageConfig) -> URL:
url = make_url(storage_config.database_url) url = make_url(storage_config.database_url)
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql": if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
url = url.set(drivername="postgresql+asyncpg") url = url.set(drivername="postgresql+asyncpg")
elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
url = url.set(drivername="mysql+aiomysql")
else: else:
url = URL.create( url = URL.create(
drivername=driver, drivername=driver,
@@ -34,6 +34,8 @@ class RunEvent(BaseModel):
class RunEventRepositoryProtocol(Protocol): class RunEventRepositoryProtocol(Protocol):
# Sequence values are time-ordered integer cursors. The application layer
# owns the single-writer invariant for a thread while a run is active.
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ... async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ...
async def list_messages( async def list_messages(
@@ -1,6 +1,9 @@
from __future__ import annotations from __future__ import annotations
import json import json
import secrets
import threading
import time
from typing import Any from typing import Any
from sqlalchemy import delete, func, select from sqlalchemy import delete, func, select
@@ -9,6 +12,26 @@ from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
from store.repositories.models.run_event import RunEvent as RunEventModel from store.repositories.models.run_event import RunEvent as RunEventModel
_SEQ_COUNTER_BITS = 12
_SEQ_PROCESS_BITS = 9
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
_last_seq_millis = 0
_seq_lock = threading.Lock()
def _allocate_sequence_base(batch_size: int) -> int:
if batch_size >= _SEQ_COUNTER_LIMIT:
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
global _last_seq_millis
now_ms = time.time_ns() // 1_000_000
with _seq_lock:
seq_ms = max(now_ms, _last_seq_millis + 1)
_last_seq_millis = seq_ms
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]: def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
if not isinstance(content, str): if not isinstance(content, str):
@@ -52,28 +75,17 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
if not events: if not events:
return [] return []
thread_ids = {event.thread_id for event in events} seq_base = _allocate_sequence_base(len(events))
seq_by_thread: dict[str, int] = {}
for thread_id in thread_ids:
max_seq = await self._session.scalar(
select(RunEventModel.seq)
.where(RunEventModel.thread_id == thread_id)
.order_by(RunEventModel.seq.desc())
.limit(1)
.with_for_update()
)
seq_by_thread[thread_id] = max_seq or 0
rows: list[RunEventModel] = [] rows: list[RunEventModel] = []
for event in events: for index, event in enumerate(events, start=1):
seq_by_thread[event.thread_id] += 1
content, metadata = _serialize_content(event.content, dict(event.metadata)) content, metadata = _serialize_content(event.content, dict(event.metadata))
row = RunEventModel( row = RunEventModel(
thread_id=event.thread_id, thread_id=event.thread_id,
run_id=event.run_id, run_id=event.run_id,
user_id=event.user_id, user_id=event.user_id,
seq=seq_by_thread[event.thread_id], seq=seq_base + index,
event_type=event.event_type, event_type=event.event_type,
category=event.category, category=event.category,
content=content, content=content,
@@ -5,10 +5,7 @@ from datetime import datetime
from sqlalchemy import Integer, String, UniqueConstraint from sqlalchemy import Integer, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
from store.utils import get_timezone
_tz = get_timezone()
class Feedback(DataClassBase): class Feedback(DataClassBase):
@@ -33,7 +30,7 @@ class Feedback(DataClassBase):
"created_at", "created_at",
TimeZone, TimeZone,
init=False, init=False,
default_factory=_tz.now, default_factory=current_time,
sort_order=999, sort_order=999,
comment="Created at", comment="Created at",
) )
@@ -6,10 +6,7 @@ from typing import Any
from sqlalchemy import JSON, Index, Integer, String from sqlalchemy import JSON, Index, Integer, String
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
from store.utils import get_timezone
_tz = get_timezone()
class Run(DataClassBase): class Run(DataClassBase):
@@ -51,7 +48,7 @@ class Run(DataClassBase):
"created_at", "created_at",
TimeZone, TimeZone,
init=False, init=False,
default_factory=_tz.now, default_factory=current_time,
sort_order=999, sort_order=999,
comment="Created at", comment="Created at",
) )
@@ -60,7 +57,7 @@ class Run(DataClassBase):
TimeZone, TimeZone,
init=False, init=False,
default=None, default=None,
onupdate=_tz.now, onupdate=current_time,
sort_order=999, sort_order=999,
comment="Updated at", comment="Updated at",
) )
@@ -3,13 +3,16 @@ from __future__ import annotations
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from sqlalchemy import JSON, Index, Integer, String, UniqueConstraint from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key from store.persistence.base_model import (
from store.utils import get_timezone DataClassBase,
TimeZone,
_tz = get_timezone() UniversalText,
current_time,
id_key,
)
class RunEvent(DataClassBase): class RunEvent(DataClassBase):
@@ -31,13 +34,13 @@ class RunEvent(DataClassBase):
category: Mapped[str] = mapped_column(String(16), index=True) category: Mapped[str] = mapped_column(String(16), index=True)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True) user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
seq: Mapped[int] = mapped_column(Integer, default=0, index=True) seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
content: Mapped[str] = mapped_column(UniversalText, default="") content: Mapped[str] = mapped_column(UniversalText, default="")
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict) meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
TimeZone, TimeZone,
init=False, init=False,
default_factory=_tz.now, default_factory=current_time,
sort_order=999, sort_order=999,
comment="Event timestamp", comment="Event timestamp",
) )
@@ -6,10 +6,7 @@ from typing import Any
from sqlalchemy import JSON, String from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone from store.persistence.base_model import DataClassBase, TimeZone, current_time
from store.utils import get_timezone
_tz = get_timezone()
class ThreadMeta(DataClassBase): class ThreadMeta(DataClassBase):
@@ -31,7 +28,7 @@ class ThreadMeta(DataClassBase):
"created_at", "created_at",
TimeZone, TimeZone,
init=False, init=False,
default_factory=_tz.now, default_factory=current_time,
sort_order=999, sort_order=999,
comment="Created at", comment="Created at",
) )
@@ -40,7 +37,7 @@ class ThreadMeta(DataClassBase):
TimeZone, TimeZone,
init=False, init=False,
default=None, default=None,
onupdate=_tz.now, onupdate=current_time,
sort_order=999, sort_order=999,
comment="Updated at", comment="Updated at",
) )
@@ -5,10 +5,7 @@ from datetime import datetime
from sqlalchemy import Boolean, Index, String, text from sqlalchemy import Boolean, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone from store.persistence.base_model import DataClassBase, TimeZone, current_time
from store.utils import get_timezone
_tz = get_timezone()
class User(DataClassBase): class User(DataClassBase):
@@ -39,7 +36,7 @@ class User(DataClassBase):
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
TimeZone, TimeZone,
init=False, init=False,
default_factory=_tz.now, default_factory=current_time,
sort_order=999, sort_order=999,
comment="Created at", comment="Created at",
) )
+1
View File
@@ -26,6 +26,7 @@ dependencies = [
[project.optional-dependencies] [project.optional-dependencies]
postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"] postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
mysql = ["deerflow-storage[mysql]"]
[dependency-groups] [dependency-groups]
dev = [ dev = [
@@ -1,6 +1,8 @@
from __future__ import annotations from __future__ import annotations
import os import os
import subprocess
import sys
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
@@ -62,6 +64,30 @@ def test_database_postgres_config_preserves_url_and_pool_options():
assert url.database == "deerflow" assert url.database == "deerflow"
def test_mysql_database_url_is_normalized_to_async_driver():
storage = StorageConfig(
driver="mysql",
database_url="mysql://user:pass@db.example:3306/deerflow",
)
url = _create_database_url(storage)
assert url.drivername == "mysql+aiomysql"
assert url.database == "deerflow"
def test_mysql_async_database_url_is_preserved():
storage = StorageConfig(
driver="mysql",
database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow",
)
url = _create_database_url(storage)
assert url.drivername == "mysql+asyncmy"
assert url.database == "deerflow"
def test_database_postgres_requires_url(): def test_database_postgres_requires_url():
database = SimpleNamespace(backend="postgres", postgres_url="") database = SimpleNamespace(backend="postgres", postgres_url="")
@@ -74,3 +100,25 @@ def test_unsupported_database_backend_rejected():
with pytest.raises(ValueError, match="Unsupported database backend"): with pytest.raises(ValueError, match="Unsupported database backend"):
storage_config_from_database_config(database) storage_config_from_database_config(database)
def test_storage_models_import_without_config_file(tmp_path):
env = os.environ.copy()
env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml")
result = subprocess.run(
[
sys.executable,
"-c",
"from store.persistence.base_model import UniversalText, id_key; "
"from store.repositories.models import RunEvent; "
"print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
],
check=False,
capture_output=True,
env=env,
text=True,
)
assert result.returncode == 0, result.stderr
assert "UniversalText run_events" in result.stdout
+22 -11
View File
@@ -348,25 +348,23 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
) )
await session.commit() await session.commit()
assert [(row.thread_id, row.seq) for row in rows] == [ assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
("thread-1", 1), assert [row.seq for row in rows] == sorted(row.seq for row in rows)
("thread-1", 2), assert rows[1].seq == rows[0].seq + 1
("thread-1", 3), assert rows[2].seq == rows[1].seq + 1
("thread-2", 1),
]
assert rows[0].content == {"role": "user", "content": "hello"} assert rows[0].content == {"role": "user", "content": "hello"}
assert rows[0].metadata == {"source": "input", "content_is_json": True} assert rows[0].metadata == {"source": "input", "content_is_json": True}
async with persistence.session_factory() as session: async with persistence.session_factory() as session:
repo = build_run_event_repository(session) repo = build_run_event_repository(session)
messages = await repo.list_messages("thread-1", limit=2) messages = await repo.list_messages("thread-1", limit=2)
assert [event.seq for event in messages] == [1, 3] assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
assert await repo.count_messages("thread-1") == 2 assert await repo.count_messages("thread-1") == 2
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5) after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
assert [event.seq for event in after] == [1] assert [event.seq for event in after] == [rows[0].seq]
before = await repo.list_messages("thread-1", before_seq=3, limit=5) before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
assert [event.seq for event in before] == [1] assert [event.seq for event in before] == [rows[0].seq]
events = await repo.list_events("thread-1", "run-1", event_types=["tool"]) events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
assert [event.content for event in events] == ["tool-call"] assert [event.content for event in events] == ["tool-call"]
@@ -378,7 +376,20 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
async with persistence.session_factory() as session: async with persistence.session_factory() as session:
repo = build_run_event_repository(session) repo = build_run_event_repository(session)
remaining = await repo.list_events("thread-1", "run-2") remaining = await repo.list_events("thread-1", "run-2")
assert [event.seq for event in remaining] == [3] assert [event.seq for event in remaining] == [rows[2].seq]
assert await repo.count_messages("thread-2") == 0 assert await repo.count_messages("thread-2") == 0
later = await repo.append_batch(
[
RunEventCreate(
thread_id="thread-1",
run_id="run-4",
event_type="message",
category="message",
content="after-delete",
)
]
)
assert later[0].seq > rows[2].seq
finally: finally:
await persistence.aclose() await persistence.aclose()
+5 -1
View File
@@ -777,6 +777,9 @@ dependencies = [
] ]
[package.optional-dependencies] [package.optional-dependencies]
mysql = [
{ name = "deerflow-storage", extra = ["mysql"] },
]
postgres = [ postgres = [
{ name = "deerflow-harness", extra = ["postgres"] }, { name = "deerflow-harness", extra = ["postgres"] },
{ name = "deerflow-storage", extra = ["postgres"] }, { name = "deerflow-storage", extra = ["postgres"] },
@@ -796,6 +799,7 @@ requires-dist = [
{ name = "deerflow-harness", editable = "packages/harness" }, { name = "deerflow-harness", editable = "packages/harness" },
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" }, { name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
{ name = "deerflow-storage", editable = "packages/storage" }, { name = "deerflow-storage", editable = "packages/storage" },
{ name = "deerflow-storage", extras = ["mysql"], marker = "extra == 'mysql'", editable = "packages/storage" },
{ name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" }, { name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
{ name = "dingtalk-stream", specifier = ">=0.24.3" }, { name = "dingtalk-stream", specifier = ">=0.24.3" },
{ name = "email-validator", specifier = ">=2.0.0" }, { name = "email-validator", specifier = ">=2.0.0" },
@@ -812,7 +816,7 @@ requires-dist = [
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" }, { name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" }, { name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
] ]
provides-extras = ["postgres"] provides-extras = ["postgres", "mysql"]
[package.metadata.requires-dev] [package.metadata.requires-dev]
dev = [ dev = [