fix(storage): address repository review feedback

This commit is contained in:
rayhpeng
2026-05-13 12:51:45 +08:00
parent d3066a1746
commit 11a9041b65
13 changed files with 140 additions and 65 deletions
@@ -1,6 +1,8 @@
from __future__ import annotations
import os
import subprocess
import sys
from pathlib import Path
from types import SimpleNamespace
@@ -62,6 +64,30 @@ def test_database_postgres_config_preserves_url_and_pool_options():
assert url.database == "deerflow"
def test_mysql_database_url_is_normalized_to_async_driver():
storage = StorageConfig(
driver="mysql",
database_url="mysql://user:pass@db.example:3306/deerflow",
)
url = _create_database_url(storage)
assert url.drivername == "mysql+aiomysql"
assert url.database == "deerflow"
def test_mysql_async_database_url_is_preserved():
storage = StorageConfig(
driver="mysql",
database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow",
)
url = _create_database_url(storage)
assert url.drivername == "mysql+asyncmy"
assert url.database == "deerflow"
def test_database_postgres_requires_url():
database = SimpleNamespace(backend="postgres", postgres_url="")
@@ -74,3 +100,25 @@ def test_unsupported_database_backend_rejected():
with pytest.raises(ValueError, match="Unsupported database backend"):
storage_config_from_database_config(database)
def test_storage_models_import_without_config_file(tmp_path):
env = os.environ.copy()
env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml")
result = subprocess.run(
[
sys.executable,
"-c",
"from store.persistence.base_model import UniversalText, id_key; "
"from store.repositories.models import RunEvent; "
"print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
],
check=False,
capture_output=True,
env=env,
text=True,
)
assert result.returncode == 0, result.stderr
assert "UniversalText run_events" in result.stdout
+22 -11
View File
@@ -348,25 +348,23 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
)
await session.commit()
assert [(row.thread_id, row.seq) for row in rows] == [
("thread-1", 1),
("thread-1", 2),
("thread-1", 3),
("thread-2", 1),
]
assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
assert [row.seq for row in rows] == sorted(row.seq for row in rows)
assert rows[1].seq == rows[0].seq + 1
assert rows[2].seq == rows[1].seq + 1
assert rows[0].content == {"role": "user", "content": "hello"}
assert rows[0].metadata == {"source": "input", "content_is_json": True}
async with persistence.session_factory() as session:
repo = build_run_event_repository(session)
messages = await repo.list_messages("thread-1", limit=2)
assert [event.seq for event in messages] == [1, 3]
assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
assert await repo.count_messages("thread-1") == 2
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
assert [event.seq for event in after] == [1]
before = await repo.list_messages("thread-1", before_seq=3, limit=5)
assert [event.seq for event in before] == [1]
assert [event.seq for event in after] == [rows[0].seq]
before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
assert [event.seq for event in before] == [rows[0].seq]
events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
assert [event.content for event in events] == ["tool-call"]
@@ -378,7 +376,20 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
async with persistence.session_factory() as session:
repo = build_run_event_repository(session)
remaining = await repo.list_events("thread-1", "run-2")
assert [event.seq for event in remaining] == [3]
assert [event.seq for event in remaining] == [rows[2].seq]
assert await repo.count_messages("thread-2") == 0
later = await repo.append_batch(
[
RunEventCreate(
thread_id="thread-1",
run_id="run-4",
event_type="message",
category="message",
content="after-delete",
)
]
)
assert later[0].seq > rows[2].seq
finally:
await persistence.aclose()