diff --git a/backend/packages/storage/store/persistence/base_model.py b/backend/packages/storage/store/persistence/base_model.py index e60562020..ed9006718 100644 --- a/backend/packages/storage/store/persistence/base_model.py +++ b/backend/packages/storage/store/persistence/base_model.py @@ -6,20 +6,17 @@ from sqlalchemy.dialects.mysql import LONGTEXT from sqlalchemy.ext.asyncio import AsyncAttrs from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column -from store.common import DataBaseType -from store.config.app_config import get_app_config from store.utils import get_timezone -timezone = get_timezone() -app_config = get_app_config() -# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT) -_id_type = Integer if app_config.storage.driver == DataBaseType.sqlite else BigInteger +def current_time() -> datetime: + return get_timezone().now() + id_key = Annotated[ int, mapped_column( - _id_type, + BigInteger().with_variant(Integer, "sqlite"), primary_key=True, unique=True, index=True, @@ -33,9 +30,14 @@ id_key = Annotated[ class UniversalText(TypeDecorator[str]): """Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL).""" - impl = LONGTEXT if DataBaseType.mysql == app_config.storage.driver else Text + impl = Text cache_ok = True + def load_dialect_impl(self, dialect): # noqa: ANN001 + if dialect.name == "mysql": + return dialect.type_descriptor(LONGTEXT()) + return dialect.type_descriptor(Text()) + def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001 return value @@ -54,11 +56,13 @@ class TimeZone(TypeDecorator[datetime]): return datetime def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001 + timezone = get_timezone() if value is not None and value.utcoffset() != timezone.now().utcoffset(): value = timezone.from_datetime(value) return value def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001 + timezone = get_timezone() if value is not None and value.tzinfo is None: value = value.replace(tzinfo=timezone.tz_info) return value @@ -70,14 +74,14 @@ class DateTimeMixin(MappedAsDataclass): created_time: Mapped[datetime] = mapped_column( TimeZone, init=False, - default_factory=timezone.now, + default_factory=current_time, sort_order=999, comment="Created at", ) updated_time: Mapped[datetime | None] = mapped_column( TimeZone, init=False, - onupdate=timezone.now, + onupdate=current_time, sort_order=999, comment="Updated at", ) diff --git a/backend/packages/storage/store/persistence/factory.py b/backend/packages/storage/store/persistence/factory.py index 30de5230e..73fd6b113 100644 --- a/backend/packages/storage/store/persistence/factory.py +++ b/backend/packages/storage/store/persistence/factory.py @@ -70,6 +70,8 @@ def _create_database_url(storage_config: StorageConfig) -> URL: url = make_url(storage_config.database_url) if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql": url = url.set(drivername="postgresql+asyncpg") + elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql": + url = url.set(drivername="mysql+aiomysql") else: url = URL.create( drivername=driver, diff --git a/backend/packages/storage/store/repositories/contracts/run_event.py b/backend/packages/storage/store/repositories/contracts/run_event.py index 1f0960337..d0cb11aa3 100644 --- a/backend/packages/storage/store/repositories/contracts/run_event.py +++ b/backend/packages/storage/store/repositories/contracts/run_event.py @@ -34,6 +34,8 @@ class RunEvent(BaseModel): class RunEventRepositoryProtocol(Protocol): + # Sequence values are time-ordered integer cursors. The application layer + # owns the single-writer invariant for a thread while a run is active. async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ... async def list_messages( diff --git a/backend/packages/storage/store/repositories/db/run_event.py b/backend/packages/storage/store/repositories/db/run_event.py index 9d22985c0..a8c312e8f 100644 --- a/backend/packages/storage/store/repositories/db/run_event.py +++ b/backend/packages/storage/store/repositories/db/run_event.py @@ -1,6 +1,9 @@ from __future__ import annotations import json +import secrets +import threading +import time from typing import Any from sqlalchemy import delete, func, select @@ -9,6 +12,26 @@ from sqlalchemy.ext.asyncio import AsyncSession from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol from store.repositories.models.run_event import RunEvent as RunEventModel +_SEQ_COUNTER_BITS = 12 +_SEQ_PROCESS_BITS = 9 +_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS) +_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS +_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS +_last_seq_millis = 0 +_seq_lock = threading.Lock() + + +def _allocate_sequence_base(batch_size: int) -> int: + if batch_size >= _SEQ_COUNTER_LIMIT: + raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}") + + global _last_seq_millis + now_ms = time.time_ns() // 1_000_000 + with _seq_lock: + seq_ms = max(now_ms, _last_seq_millis + 1) + _last_seq_millis = seq_ms + return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS) + def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]: if not isinstance(content, str): @@ -52,28 +75,17 @@ class DbRunEventRepository(RunEventRepositoryProtocol): if not events: return [] - thread_ids = {event.thread_id for event in events} - seq_by_thread: dict[str, int] = {} - for thread_id in thread_ids: - max_seq = await self._session.scalar( - select(RunEventModel.seq) - .where(RunEventModel.thread_id == thread_id) - .order_by(RunEventModel.seq.desc()) - .limit(1) - .with_for_update() - ) - seq_by_thread[thread_id] = max_seq or 0 + seq_base = _allocate_sequence_base(len(events)) rows: list[RunEventModel] = [] - for event in events: - seq_by_thread[event.thread_id] += 1 + for index, event in enumerate(events, start=1): content, metadata = _serialize_content(event.content, dict(event.metadata)) row = RunEventModel( thread_id=event.thread_id, run_id=event.run_id, user_id=event.user_id, - seq=seq_by_thread[event.thread_id], + seq=seq_base + index, event_type=event.event_type, category=event.category, content=content, diff --git a/backend/packages/storage/store/repositories/models/feedback.py b/backend/packages/storage/store/repositories/models/feedback.py index 581a91bbc..7b07fe44b 100644 --- a/backend/packages/storage/store/repositories/models/feedback.py +++ b/backend/packages/storage/store/repositories/models/feedback.py @@ -5,10 +5,7 @@ from datetime import datetime from sqlalchemy import Integer, String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column -from store.persistence.base_model import DataClassBase, TimeZone, UniversalText -from store.utils import get_timezone - -_tz = get_timezone() +from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time class Feedback(DataClassBase): @@ -33,7 +30,7 @@ class Feedback(DataClassBase): "created_at", TimeZone, init=False, - default_factory=_tz.now, + default_factory=current_time, sort_order=999, comment="Created at", ) diff --git a/backend/packages/storage/store/repositories/models/run.py b/backend/packages/storage/store/repositories/models/run.py index dd0f93b88..cc30c8d91 100644 --- a/backend/packages/storage/store/repositories/models/run.py +++ b/backend/packages/storage/store/repositories/models/run.py @@ -6,10 +6,7 @@ from typing import Any from sqlalchemy import JSON, Index, Integer, String from sqlalchemy.orm import Mapped, mapped_column -from store.persistence.base_model import DataClassBase, TimeZone, UniversalText -from store.utils import get_timezone - -_tz = get_timezone() +from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time class Run(DataClassBase): @@ -51,7 +48,7 @@ class Run(DataClassBase): "created_at", TimeZone, init=False, - default_factory=_tz.now, + default_factory=current_time, sort_order=999, comment="Created at", ) @@ -60,7 +57,7 @@ class Run(DataClassBase): TimeZone, init=False, default=None, - onupdate=_tz.now, + onupdate=current_time, sort_order=999, comment="Updated at", ) diff --git a/backend/packages/storage/store/repositories/models/run_event.py b/backend/packages/storage/store/repositories/models/run_event.py index b07665d1c..8651ad563 100644 --- a/backend/packages/storage/store/repositories/models/run_event.py +++ b/backend/packages/storage/store/repositories/models/run_event.py @@ -3,13 +3,16 @@ from __future__ import annotations from datetime import datetime from typing import Any -from sqlalchemy import JSON, Index, Integer, String, UniqueConstraint +from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column -from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key -from store.utils import get_timezone - -_tz = get_timezone() +from store.persistence.base_model import ( + DataClassBase, + TimeZone, + UniversalText, + current_time, + id_key, +) class RunEvent(DataClassBase): @@ -31,13 +34,13 @@ class RunEvent(DataClassBase): category: Mapped[str] = mapped_column(String(16), index=True) user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True) - seq: Mapped[int] = mapped_column(Integer, default=0, index=True) + seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True) content: Mapped[str] = mapped_column(UniversalText, default="") meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict) created_at: Mapped[datetime] = mapped_column( TimeZone, init=False, - default_factory=_tz.now, + default_factory=current_time, sort_order=999, comment="Event timestamp", ) diff --git a/backend/packages/storage/store/repositories/models/thread_meta.py b/backend/packages/storage/store/repositories/models/thread_meta.py index ce3e70f27..8ba1a7e73 100644 --- a/backend/packages/storage/store/repositories/models/thread_meta.py +++ b/backend/packages/storage/store/repositories/models/thread_meta.py @@ -6,10 +6,7 @@ from typing import Any from sqlalchemy import JSON, String from sqlalchemy.orm import Mapped, mapped_column -from store.persistence.base_model import DataClassBase, TimeZone -from store.utils import get_timezone - -_tz = get_timezone() +from store.persistence.base_model import DataClassBase, TimeZone, current_time class ThreadMeta(DataClassBase): @@ -31,7 +28,7 @@ class ThreadMeta(DataClassBase): "created_at", TimeZone, init=False, - default_factory=_tz.now, + default_factory=current_time, sort_order=999, comment="Created at", ) @@ -40,7 +37,7 @@ class ThreadMeta(DataClassBase): TimeZone, init=False, default=None, - onupdate=_tz.now, + onupdate=current_time, sort_order=999, comment="Updated at", ) diff --git a/backend/packages/storage/store/repositories/models/user.py b/backend/packages/storage/store/repositories/models/user.py index a017ec47e..e0a508189 100644 --- a/backend/packages/storage/store/repositories/models/user.py +++ b/backend/packages/storage/store/repositories/models/user.py @@ -5,10 +5,7 @@ from datetime import datetime from sqlalchemy import Boolean, Index, String, text from sqlalchemy.orm import Mapped, mapped_column -from store.persistence.base_model import DataClassBase, TimeZone -from store.utils import get_timezone - -_tz = get_timezone() +from store.persistence.base_model import DataClassBase, TimeZone, current_time class User(DataClassBase): @@ -39,7 +36,7 @@ class User(DataClassBase): created_at: Mapped[datetime] = mapped_column( TimeZone, init=False, - default_factory=_tz.now, + default_factory=current_time, sort_order=999, comment="Created at", ) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index 6722b255f..b662e50c6 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -26,6 +26,7 @@ dependencies = [ [project.optional-dependencies] postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"] +mysql = ["deerflow-storage[mysql]"] [dependency-groups] dev = [ diff --git a/backend/tests/test_storage_persistence_config.py b/backend/tests/test_storage_persistence_config.py index 17b290296..b718d58fb 100644 --- a/backend/tests/test_storage_persistence_config.py +++ b/backend/tests/test_storage_persistence_config.py @@ -1,6 +1,8 @@ from __future__ import annotations import os +import subprocess +import sys from pathlib import Path from types import SimpleNamespace @@ -62,6 +64,30 @@ def test_database_postgres_config_preserves_url_and_pool_options(): assert url.database == "deerflow" +def test_mysql_database_url_is_normalized_to_async_driver(): + storage = StorageConfig( + driver="mysql", + database_url="mysql://user:pass@db.example:3306/deerflow", + ) + + url = _create_database_url(storage) + + assert url.drivername == "mysql+aiomysql" + assert url.database == "deerflow" + + +def test_mysql_async_database_url_is_preserved(): + storage = StorageConfig( + driver="mysql", + database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow", + ) + + url = _create_database_url(storage) + + assert url.drivername == "mysql+asyncmy" + assert url.database == "deerflow" + + def test_database_postgres_requires_url(): database = SimpleNamespace(backend="postgres", postgres_url="") @@ -74,3 +100,25 @@ def test_unsupported_database_backend_rejected(): with pytest.raises(ValueError, match="Unsupported database backend"): storage_config_from_database_config(database) + + +def test_storage_models_import_without_config_file(tmp_path): + env = os.environ.copy() + env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml") + + result = subprocess.run( + [ + sys.executable, + "-c", + "from store.persistence.base_model import UniversalText, id_key; " + "from store.repositories.models import RunEvent; " + "print(UniversalText.__name__, RunEvent.__tablename__, id_key)", + ], + check=False, + capture_output=True, + env=env, + text=True, + ) + + assert result.returncode == 0, result.stderr + assert "UniversalText run_events" in result.stdout diff --git a/backend/tests/test_storage_repositories.py b/backend/tests/test_storage_repositories.py index 0e66d934c..bfe0525a7 100644 --- a/backend/tests/test_storage_repositories.py +++ b/backend/tests/test_storage_repositories.py @@ -348,25 +348,23 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_ ) await session.commit() - assert [(row.thread_id, row.seq) for row in rows] == [ - ("thread-1", 1), - ("thread-1", 2), - ("thread-1", 3), - ("thread-2", 1), - ] + assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"] + assert [row.seq for row in rows] == sorted(row.seq for row in rows) + assert rows[1].seq == rows[0].seq + 1 + assert rows[2].seq == rows[1].seq + 1 assert rows[0].content == {"role": "user", "content": "hello"} assert rows[0].metadata == {"source": "input", "content_is_json": True} async with persistence.session_factory() as session: repo = build_run_event_repository(session) messages = await repo.list_messages("thread-1", limit=2) - assert [event.seq for event in messages] == [1, 3] + assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq] assert await repo.count_messages("thread-1") == 2 after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5) - assert [event.seq for event in after] == [1] - before = await repo.list_messages("thread-1", before_seq=3, limit=5) - assert [event.seq for event in before] == [1] + assert [event.seq for event in after] == [rows[0].seq] + before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5) + assert [event.seq for event in before] == [rows[0].seq] events = await repo.list_events("thread-1", "run-1", event_types=["tool"]) assert [event.content for event in events] == ["tool-call"] @@ -378,7 +376,20 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_ async with persistence.session_factory() as session: repo = build_run_event_repository(session) remaining = await repo.list_events("thread-1", "run-2") - assert [event.seq for event in remaining] == [3] + assert [event.seq for event in remaining] == [rows[2].seq] assert await repo.count_messages("thread-2") == 0 + + later = await repo.append_batch( + [ + RunEventCreate( + thread_id="thread-1", + run_id="run-4", + event_type="message", + category="message", + content="after-delete", + ) + ] + ) + assert later[0].seq > rows[2].seq finally: await persistence.aclose() diff --git a/backend/uv.lock b/backend/uv.lock index 53fcae74e..1dbda8051 100644 --- a/backend/uv.lock +++ b/backend/uv.lock @@ -777,6 +777,9 @@ dependencies = [ ] [package.optional-dependencies] +mysql = [ + { name = "deerflow-storage", extra = ["mysql"] }, +] postgres = [ { name = "deerflow-harness", extra = ["postgres"] }, { name = "deerflow-storage", extra = ["postgres"] }, @@ -796,6 +799,7 @@ requires-dist = [ { name = "deerflow-harness", editable = "packages/harness" }, { name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" }, { name = "deerflow-storage", editable = "packages/storage" }, + { name = "deerflow-storage", extras = ["mysql"], marker = "extra == 'mysql'", editable = "packages/storage" }, { name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" }, { name = "dingtalk-stream", specifier = ">=0.24.3" }, { name = "email-validator", specifier = ">=2.0.0" }, @@ -812,7 +816,7 @@ requires-dist = [ { name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" }, { name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" }, ] -provides-extras = ["postgres"] +provides-extras = ["postgres", "mysql"] [package.metadata.requires-dev] dev = [