mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 00:45:57 +00:00
fix(storage): address repository review feedback
This commit is contained in:
@@ -6,20 +6,17 @@ from sqlalchemy.dialects.mysql import LONGTEXT
|
|||||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||||
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
|
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
|
||||||
|
|
||||||
from store.common import DataBaseType
|
|
||||||
from store.config.app_config import get_app_config
|
|
||||||
from store.utils import get_timezone
|
from store.utils import get_timezone
|
||||||
|
|
||||||
timezone = get_timezone()
|
|
||||||
app_config = get_app_config()
|
|
||||||
|
|
||||||
# SQLite autoincrement only works with INTEGER PRIMARY KEY (not BIGINT)
|
def current_time() -> datetime:
|
||||||
_id_type = Integer if app_config.storage.driver == DataBaseType.sqlite else BigInteger
|
return get_timezone().now()
|
||||||
|
|
||||||
|
|
||||||
id_key = Annotated[
|
id_key = Annotated[
|
||||||
int,
|
int,
|
||||||
mapped_column(
|
mapped_column(
|
||||||
_id_type,
|
BigInteger().with_variant(Integer, "sqlite"),
|
||||||
primary_key=True,
|
primary_key=True,
|
||||||
unique=True,
|
unique=True,
|
||||||
index=True,
|
index=True,
|
||||||
@@ -33,9 +30,14 @@ id_key = Annotated[
|
|||||||
class UniversalText(TypeDecorator[str]):
|
class UniversalText(TypeDecorator[str]):
|
||||||
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
|
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
|
||||||
|
|
||||||
impl = LONGTEXT if DataBaseType.mysql == app_config.storage.driver else Text
|
impl = Text
|
||||||
cache_ok = True
|
cache_ok = True
|
||||||
|
|
||||||
|
def load_dialect_impl(self, dialect): # noqa: ANN001
|
||||||
|
if dialect.name == "mysql":
|
||||||
|
return dialect.type_descriptor(LONGTEXT())
|
||||||
|
return dialect.type_descriptor(Text())
|
||||||
|
|
||||||
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
|
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
|
||||||
return value
|
return value
|
||||||
|
|
||||||
@@ -54,11 +56,13 @@ class TimeZone(TypeDecorator[datetime]):
|
|||||||
return datetime
|
return datetime
|
||||||
|
|
||||||
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
||||||
|
timezone = get_timezone()
|
||||||
if value is not None and value.utcoffset() != timezone.now().utcoffset():
|
if value is not None and value.utcoffset() != timezone.now().utcoffset():
|
||||||
value = timezone.from_datetime(value)
|
value = timezone.from_datetime(value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
||||||
|
timezone = get_timezone()
|
||||||
if value is not None and value.tzinfo is None:
|
if value is not None and value.tzinfo is None:
|
||||||
value = value.replace(tzinfo=timezone.tz_info)
|
value = value.replace(tzinfo=timezone.tz_info)
|
||||||
return value
|
return value
|
||||||
@@ -70,14 +74,14 @@ class DateTimeMixin(MappedAsDataclass):
|
|||||||
created_time: Mapped[datetime] = mapped_column(
|
created_time: Mapped[datetime] = mapped_column(
|
||||||
TimeZone,
|
TimeZone,
|
||||||
init=False,
|
init=False,
|
||||||
default_factory=timezone.now,
|
default_factory=current_time,
|
||||||
sort_order=999,
|
sort_order=999,
|
||||||
comment="Created at",
|
comment="Created at",
|
||||||
)
|
)
|
||||||
updated_time: Mapped[datetime | None] = mapped_column(
|
updated_time: Mapped[datetime | None] = mapped_column(
|
||||||
TimeZone,
|
TimeZone,
|
||||||
init=False,
|
init=False,
|
||||||
onupdate=timezone.now,
|
onupdate=current_time,
|
||||||
sort_order=999,
|
sort_order=999,
|
||||||
comment="Updated at",
|
comment="Updated at",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -70,6 +70,8 @@ def _create_database_url(storage_config: StorageConfig) -> URL:
|
|||||||
url = make_url(storage_config.database_url)
|
url = make_url(storage_config.database_url)
|
||||||
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
|
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
|
||||||
url = url.set(drivername="postgresql+asyncpg")
|
url = url.set(drivername="postgresql+asyncpg")
|
||||||
|
elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
|
||||||
|
url = url.set(drivername="mysql+aiomysql")
|
||||||
else:
|
else:
|
||||||
url = URL.create(
|
url = URL.create(
|
||||||
drivername=driver,
|
drivername=driver,
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ class RunEvent(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class RunEventRepositoryProtocol(Protocol):
|
class RunEventRepositoryProtocol(Protocol):
|
||||||
|
# Sequence values are time-ordered integer cursors. The application layer
|
||||||
|
# owns the single-writer invariant for a thread while a run is active.
|
||||||
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ...
|
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]: ...
|
||||||
|
|
||||||
async def list_messages(
|
async def list_messages(
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import secrets
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import delete, func, select
|
||||||
@@ -9,6 +12,26 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
|
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
|
||||||
from store.repositories.models.run_event import RunEvent as RunEventModel
|
from store.repositories.models.run_event import RunEvent as RunEventModel
|
||||||
|
|
||||||
|
_SEQ_COUNTER_BITS = 12
|
||||||
|
_SEQ_PROCESS_BITS = 9
|
||||||
|
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
|
||||||
|
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
|
||||||
|
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
|
||||||
|
_last_seq_millis = 0
|
||||||
|
_seq_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def _allocate_sequence_base(batch_size: int) -> int:
|
||||||
|
if batch_size >= _SEQ_COUNTER_LIMIT:
|
||||||
|
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
|
||||||
|
|
||||||
|
global _last_seq_millis
|
||||||
|
now_ms = time.time_ns() // 1_000_000
|
||||||
|
with _seq_lock:
|
||||||
|
seq_ms = max(now_ms, _last_seq_millis + 1)
|
||||||
|
_last_seq_millis = seq_ms
|
||||||
|
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
|
||||||
|
|
||||||
|
|
||||||
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||||
if not isinstance(content, str):
|
if not isinstance(content, str):
|
||||||
@@ -52,28 +75,17 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
|||||||
if not events:
|
if not events:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
thread_ids = {event.thread_id for event in events}
|
seq_base = _allocate_sequence_base(len(events))
|
||||||
seq_by_thread: dict[str, int] = {}
|
|
||||||
for thread_id in thread_ids:
|
|
||||||
max_seq = await self._session.scalar(
|
|
||||||
select(RunEventModel.seq)
|
|
||||||
.where(RunEventModel.thread_id == thread_id)
|
|
||||||
.order_by(RunEventModel.seq.desc())
|
|
||||||
.limit(1)
|
|
||||||
.with_for_update()
|
|
||||||
)
|
|
||||||
seq_by_thread[thread_id] = max_seq or 0
|
|
||||||
|
|
||||||
rows: list[RunEventModel] = []
|
rows: list[RunEventModel] = []
|
||||||
|
|
||||||
for event in events:
|
for index, event in enumerate(events, start=1):
|
||||||
seq_by_thread[event.thread_id] += 1
|
|
||||||
content, metadata = _serialize_content(event.content, dict(event.metadata))
|
content, metadata = _serialize_content(event.content, dict(event.metadata))
|
||||||
row = RunEventModel(
|
row = RunEventModel(
|
||||||
thread_id=event.thread_id,
|
thread_id=event.thread_id,
|
||||||
run_id=event.run_id,
|
run_id=event.run_id,
|
||||||
user_id=event.user_id,
|
user_id=event.user_id,
|
||||||
seq=seq_by_thread[event.thread_id],
|
seq=seq_base + index,
|
||||||
event_type=event.event_type,
|
event_type=event.event_type,
|
||||||
category=event.category,
|
category=event.category,
|
||||||
content=content,
|
content=content,
|
||||||
|
|||||||
@@ -5,10 +5,7 @@ from datetime import datetime
|
|||||||
from sqlalchemy import Integer, String, UniqueConstraint
|
from sqlalchemy import Integer, String, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
|
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
|
||||||
from store.utils import get_timezone
|
|
||||||
|
|
||||||
_tz = get_timezone()
|
|
||||||
|
|
||||||
|
|
||||||
class Feedback(DataClassBase):
|
class Feedback(DataClassBase):
|
||||||
@@ -33,7 +30,7 @@ class Feedback(DataClassBase):
|
|||||||
"created_at",
|
"created_at",
|
||||||
TimeZone,
|
TimeZone,
|
||||||
init=False,
|
init=False,
|
||||||
default_factory=_tz.now,
|
default_factory=current_time,
|
||||||
sort_order=999,
|
sort_order=999,
|
||||||
comment="Created at",
|
comment="Created at",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,10 +6,7 @@ from typing import Any
|
|||||||
from sqlalchemy import JSON, Index, Integer, String
|
from sqlalchemy import JSON, Index, Integer, String
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText
|
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
|
||||||
from store.utils import get_timezone
|
|
||||||
|
|
||||||
_tz = get_timezone()
|
|
||||||
|
|
||||||
|
|
||||||
class Run(DataClassBase):
|
class Run(DataClassBase):
|
||||||
@@ -51,7 +48,7 @@ class Run(DataClassBase):
|
|||||||
"created_at",
|
"created_at",
|
||||||
TimeZone,
|
TimeZone,
|
||||||
init=False,
|
init=False,
|
||||||
default_factory=_tz.now,
|
default_factory=current_time,
|
||||||
sort_order=999,
|
sort_order=999,
|
||||||
comment="Created at",
|
comment="Created at",
|
||||||
)
|
)
|
||||||
@@ -60,7 +57,7 @@ class Run(DataClassBase):
|
|||||||
TimeZone,
|
TimeZone,
|
||||||
init=False,
|
init=False,
|
||||||
default=None,
|
default=None,
|
||||||
onupdate=_tz.now,
|
onupdate=current_time,
|
||||||
sort_order=999,
|
sort_order=999,
|
||||||
comment="Updated at",
|
comment="Updated at",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -3,13 +3,16 @@ from __future__ import annotations
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import JSON, Index, Integer, String, UniqueConstraint
|
from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, id_key
|
from store.persistence.base_model import (
|
||||||
from store.utils import get_timezone
|
DataClassBase,
|
||||||
|
TimeZone,
|
||||||
_tz = get_timezone()
|
UniversalText,
|
||||||
|
current_time,
|
||||||
|
id_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RunEvent(DataClassBase):
|
class RunEvent(DataClassBase):
|
||||||
@@ -31,13 +34,13 @@ class RunEvent(DataClassBase):
|
|||||||
category: Mapped[str] = mapped_column(String(16), index=True)
|
category: Mapped[str] = mapped_column(String(16), index=True)
|
||||||
|
|
||||||
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
|
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
|
||||||
seq: Mapped[int] = mapped_column(Integer, default=0, index=True)
|
seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
|
||||||
content: Mapped[str] = mapped_column(UniversalText, default="")
|
content: Mapped[str] = mapped_column(UniversalText, default="")
|
||||||
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
|
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
TimeZone,
|
TimeZone,
|
||||||
init=False,
|
init=False,
|
||||||
default_factory=_tz.now,
|
default_factory=current_time,
|
||||||
sort_order=999,
|
sort_order=999,
|
||||||
comment="Event timestamp",
|
comment="Event timestamp",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -6,10 +6,7 @@ from typing import Any
|
|||||||
from sqlalchemy import JSON, String
|
from sqlalchemy import JSON, String
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from store.persistence.base_model import DataClassBase, TimeZone
|
from store.persistence.base_model import DataClassBase, TimeZone, current_time
|
||||||
from store.utils import get_timezone
|
|
||||||
|
|
||||||
_tz = get_timezone()
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadMeta(DataClassBase):
|
class ThreadMeta(DataClassBase):
|
||||||
@@ -31,7 +28,7 @@ class ThreadMeta(DataClassBase):
|
|||||||
"created_at",
|
"created_at",
|
||||||
TimeZone,
|
TimeZone,
|
||||||
init=False,
|
init=False,
|
||||||
default_factory=_tz.now,
|
default_factory=current_time,
|
||||||
sort_order=999,
|
sort_order=999,
|
||||||
comment="Created at",
|
comment="Created at",
|
||||||
)
|
)
|
||||||
@@ -40,7 +37,7 @@ class ThreadMeta(DataClassBase):
|
|||||||
TimeZone,
|
TimeZone,
|
||||||
init=False,
|
init=False,
|
||||||
default=None,
|
default=None,
|
||||||
onupdate=_tz.now,
|
onupdate=current_time,
|
||||||
sort_order=999,
|
sort_order=999,
|
||||||
comment="Updated at",
|
comment="Updated at",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -5,10 +5,7 @@ from datetime import datetime
|
|||||||
from sqlalchemy import Boolean, Index, String, text
|
from sqlalchemy import Boolean, Index, String, text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from store.persistence.base_model import DataClassBase, TimeZone
|
from store.persistence.base_model import DataClassBase, TimeZone, current_time
|
||||||
from store.utils import get_timezone
|
|
||||||
|
|
||||||
_tz = get_timezone()
|
|
||||||
|
|
||||||
|
|
||||||
class User(DataClassBase):
|
class User(DataClassBase):
|
||||||
@@ -39,7 +36,7 @@ class User(DataClassBase):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
TimeZone,
|
TimeZone,
|
||||||
init=False,
|
init=False,
|
||||||
default_factory=_tz.now,
|
default_factory=current_time,
|
||||||
sort_order=999,
|
sort_order=999,
|
||||||
comment="Created at",
|
comment="Created at",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ dependencies = [
|
|||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
|
postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
|
||||||
|
mysql = ["deerflow-storage[mysql]"]
|
||||||
|
|
||||||
[dependency-groups]
|
[dependency-groups]
|
||||||
dev = [
|
dev = [
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
@@ -62,6 +64,30 @@ def test_database_postgres_config_preserves_url_and_pool_options():
|
|||||||
assert url.database == "deerflow"
|
assert url.database == "deerflow"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mysql_database_url_is_normalized_to_async_driver():
|
||||||
|
storage = StorageConfig(
|
||||||
|
driver="mysql",
|
||||||
|
database_url="mysql://user:pass@db.example:3306/deerflow",
|
||||||
|
)
|
||||||
|
|
||||||
|
url = _create_database_url(storage)
|
||||||
|
|
||||||
|
assert url.drivername == "mysql+aiomysql"
|
||||||
|
assert url.database == "deerflow"
|
||||||
|
|
||||||
|
|
||||||
|
def test_mysql_async_database_url_is_preserved():
|
||||||
|
storage = StorageConfig(
|
||||||
|
driver="mysql",
|
||||||
|
database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow",
|
||||||
|
)
|
||||||
|
|
||||||
|
url = _create_database_url(storage)
|
||||||
|
|
||||||
|
assert url.drivername == "mysql+asyncmy"
|
||||||
|
assert url.database == "deerflow"
|
||||||
|
|
||||||
|
|
||||||
def test_database_postgres_requires_url():
|
def test_database_postgres_requires_url():
|
||||||
database = SimpleNamespace(backend="postgres", postgres_url="")
|
database = SimpleNamespace(backend="postgres", postgres_url="")
|
||||||
|
|
||||||
@@ -74,3 +100,25 @@ def test_unsupported_database_backend_rejected():
|
|||||||
|
|
||||||
with pytest.raises(ValueError, match="Unsupported database backend"):
|
with pytest.raises(ValueError, match="Unsupported database backend"):
|
||||||
storage_config_from_database_config(database)
|
storage_config_from_database_config(database)
|
||||||
|
|
||||||
|
|
||||||
|
def test_storage_models_import_without_config_file(tmp_path):
|
||||||
|
env = os.environ.copy()
|
||||||
|
env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml")
|
||||||
|
|
||||||
|
result = subprocess.run(
|
||||||
|
[
|
||||||
|
sys.executable,
|
||||||
|
"-c",
|
||||||
|
"from store.persistence.base_model import UniversalText, id_key; "
|
||||||
|
"from store.repositories.models import RunEvent; "
|
||||||
|
"print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
|
||||||
|
],
|
||||||
|
check=False,
|
||||||
|
capture_output=True,
|
||||||
|
env=env,
|
||||||
|
text=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.returncode == 0, result.stderr
|
||||||
|
assert "UniversalText run_events" in result.stdout
|
||||||
|
|||||||
@@ -348,25 +348,23 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
|
|||||||
)
|
)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
assert [(row.thread_id, row.seq) for row in rows] == [
|
assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
|
||||||
("thread-1", 1),
|
assert [row.seq for row in rows] == sorted(row.seq for row in rows)
|
||||||
("thread-1", 2),
|
assert rows[1].seq == rows[0].seq + 1
|
||||||
("thread-1", 3),
|
assert rows[2].seq == rows[1].seq + 1
|
||||||
("thread-2", 1),
|
|
||||||
]
|
|
||||||
assert rows[0].content == {"role": "user", "content": "hello"}
|
assert rows[0].content == {"role": "user", "content": "hello"}
|
||||||
assert rows[0].metadata == {"source": "input", "content_is_json": True}
|
assert rows[0].metadata == {"source": "input", "content_is_json": True}
|
||||||
|
|
||||||
async with persistence.session_factory() as session:
|
async with persistence.session_factory() as session:
|
||||||
repo = build_run_event_repository(session)
|
repo = build_run_event_repository(session)
|
||||||
messages = await repo.list_messages("thread-1", limit=2)
|
messages = await repo.list_messages("thread-1", limit=2)
|
||||||
assert [event.seq for event in messages] == [1, 3]
|
assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
|
||||||
assert await repo.count_messages("thread-1") == 2
|
assert await repo.count_messages("thread-1") == 2
|
||||||
|
|
||||||
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
|
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
|
||||||
assert [event.seq for event in after] == [1]
|
assert [event.seq for event in after] == [rows[0].seq]
|
||||||
before = await repo.list_messages("thread-1", before_seq=3, limit=5)
|
before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
|
||||||
assert [event.seq for event in before] == [1]
|
assert [event.seq for event in before] == [rows[0].seq]
|
||||||
|
|
||||||
events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
|
events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
|
||||||
assert [event.content for event in events] == ["tool-call"]
|
assert [event.content for event in events] == ["tool-call"]
|
||||||
@@ -378,7 +376,20 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
|
|||||||
async with persistence.session_factory() as session:
|
async with persistence.session_factory() as session:
|
||||||
repo = build_run_event_repository(session)
|
repo = build_run_event_repository(session)
|
||||||
remaining = await repo.list_events("thread-1", "run-2")
|
remaining = await repo.list_events("thread-1", "run-2")
|
||||||
assert [event.seq for event in remaining] == [3]
|
assert [event.seq for event in remaining] == [rows[2].seq]
|
||||||
assert await repo.count_messages("thread-2") == 0
|
assert await repo.count_messages("thread-2") == 0
|
||||||
|
|
||||||
|
later = await repo.append_batch(
|
||||||
|
[
|
||||||
|
RunEventCreate(
|
||||||
|
thread_id="thread-1",
|
||||||
|
run_id="run-4",
|
||||||
|
event_type="message",
|
||||||
|
category="message",
|
||||||
|
content="after-delete",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert later[0].seq > rows[2].seq
|
||||||
finally:
|
finally:
|
||||||
await persistence.aclose()
|
await persistence.aclose()
|
||||||
|
|||||||
Generated
+5
-1
@@ -777,6 +777,9 @@ dependencies = [
|
|||||||
]
|
]
|
||||||
|
|
||||||
[package.optional-dependencies]
|
[package.optional-dependencies]
|
||||||
|
mysql = [
|
||||||
|
{ name = "deerflow-storage", extra = ["mysql"] },
|
||||||
|
]
|
||||||
postgres = [
|
postgres = [
|
||||||
{ name = "deerflow-harness", extra = ["postgres"] },
|
{ name = "deerflow-harness", extra = ["postgres"] },
|
||||||
{ name = "deerflow-storage", extra = ["postgres"] },
|
{ name = "deerflow-storage", extra = ["postgres"] },
|
||||||
@@ -796,6 +799,7 @@ requires-dist = [
|
|||||||
{ name = "deerflow-harness", editable = "packages/harness" },
|
{ name = "deerflow-harness", editable = "packages/harness" },
|
||||||
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
|
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
|
||||||
{ name = "deerflow-storage", editable = "packages/storage" },
|
{ name = "deerflow-storage", editable = "packages/storage" },
|
||||||
|
{ name = "deerflow-storage", extras = ["mysql"], marker = "extra == 'mysql'", editable = "packages/storage" },
|
||||||
{ name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
|
{ name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
|
||||||
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
|
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
|
||||||
{ name = "email-validator", specifier = ">=2.0.0" },
|
{ name = "email-validator", specifier = ">=2.0.0" },
|
||||||
@@ -812,7 +816,7 @@ requires-dist = [
|
|||||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
|
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
|
||||||
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
|
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
|
||||||
]
|
]
|
||||||
provides-extras = ["postgres"]
|
provides-extras = ["postgres", "mysql"]
|
||||||
|
|
||||||
[package.metadata.requires-dev]
|
[package.metadata.requires-dev]
|
||||||
dev = [
|
dev = [
|
||||||
|
|||||||
Reference in New Issue
Block a user