mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
fix(storage): address repository review feedback
This commit is contained in:
@@ -1,6 +1,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
@@ -62,6 +64,30 @@ def test_database_postgres_config_preserves_url_and_pool_options():
|
||||
assert url.database == "deerflow"
|
||||
|
||||
|
||||
def test_mysql_database_url_is_normalized_to_async_driver():
|
||||
storage = StorageConfig(
|
||||
driver="mysql",
|
||||
database_url="mysql://user:pass@db.example:3306/deerflow",
|
||||
)
|
||||
|
||||
url = _create_database_url(storage)
|
||||
|
||||
assert url.drivername == "mysql+aiomysql"
|
||||
assert url.database == "deerflow"
|
||||
|
||||
|
||||
def test_mysql_async_database_url_is_preserved():
|
||||
storage = StorageConfig(
|
||||
driver="mysql",
|
||||
database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow",
|
||||
)
|
||||
|
||||
url = _create_database_url(storage)
|
||||
|
||||
assert url.drivername == "mysql+asyncmy"
|
||||
assert url.database == "deerflow"
|
||||
|
||||
|
||||
def test_database_postgres_requires_url():
|
||||
database = SimpleNamespace(backend="postgres", postgres_url="")
|
||||
|
||||
@@ -74,3 +100,25 @@ def test_unsupported_database_backend_rejected():
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported database backend"):
|
||||
storage_config_from_database_config(database)
|
||||
|
||||
|
||||
def test_storage_models_import_without_config_file(tmp_path):
|
||||
env = os.environ.copy()
|
||||
env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml")
|
||||
|
||||
result = subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
"-c",
|
||||
"from store.persistence.base_model import UniversalText, id_key; "
|
||||
"from store.repositories.models import RunEvent; "
|
||||
"print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
|
||||
],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
env=env,
|
||||
text=True,
|
||||
)
|
||||
|
||||
assert result.returncode == 0, result.stderr
|
||||
assert "UniversalText run_events" in result.stdout
|
||||
|
||||
@@ -348,25 +348,23 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
assert [(row.thread_id, row.seq) for row in rows] == [
|
||||
("thread-1", 1),
|
||||
("thread-1", 2),
|
||||
("thread-1", 3),
|
||||
("thread-2", 1),
|
||||
]
|
||||
assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
|
||||
assert [row.seq for row in rows] == sorted(row.seq for row in rows)
|
||||
assert rows[1].seq == rows[0].seq + 1
|
||||
assert rows[2].seq == rows[1].seq + 1
|
||||
assert rows[0].content == {"role": "user", "content": "hello"}
|
||||
assert rows[0].metadata == {"source": "input", "content_is_json": True}
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
messages = await repo.list_messages("thread-1", limit=2)
|
||||
assert [event.seq for event in messages] == [1, 3]
|
||||
assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
|
||||
assert await repo.count_messages("thread-1") == 2
|
||||
|
||||
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
|
||||
assert [event.seq for event in after] == [1]
|
||||
before = await repo.list_messages("thread-1", before_seq=3, limit=5)
|
||||
assert [event.seq for event in before] == [1]
|
||||
assert [event.seq for event in after] == [rows[0].seq]
|
||||
before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
|
||||
assert [event.seq for event in before] == [rows[0].seq]
|
||||
|
||||
events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
|
||||
assert [event.content for event in events] == ["tool-call"]
|
||||
@@ -378,7 +376,20 @@ async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
remaining = await repo.list_events("thread-1", "run-2")
|
||||
assert [event.seq for event in remaining] == [3]
|
||||
assert [event.seq for event in remaining] == [rows[2].seq]
|
||||
assert await repo.count_messages("thread-2") == 0
|
||||
|
||||
later = await repo.append_batch(
|
||||
[
|
||||
RunEventCreate(
|
||||
thread_id="thread-1",
|
||||
run_id="run-4",
|
||||
event_type="message",
|
||||
category="message",
|
||||
content="after-delete",
|
||||
)
|
||||
]
|
||||
)
|
||||
assert later[0].seq > rows[2].seq
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
Reference in New Issue
Block a user