feat(storage): add storage package base

This commit is contained in:
rayhpeng
2026-05-12 19:08:37 +08:00
parent 20d2d2b373
commit 485f8a2bf2
45 changed files with 3199 additions and 2 deletions
@@ -0,0 +1,76 @@
from __future__ import annotations
import os
from pathlib import Path
from types import SimpleNamespace
import pytest
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from store.config.storage_config import StorageConfig
from store.persistence.factory import _create_database_url, storage_config_from_database_config
def test_database_sqlite_config_maps_to_storage_config(tmp_path):
    """A sqlite database config is translated field-for-field into a ``StorageConfig``."""
    sqlite_dir = str(tmp_path)
    db_config = SimpleNamespace(backend="sqlite", sqlite_dir=sqlite_dir, echo_sql=True, pool_size=9)

    result = storage_config_from_database_config(db_config)

    expected = StorageConfig(driver="sqlite", sqlite_dir=sqlite_dir, echo_sql=True, pool_size=9)
    assert result == expected
    # The database file is expected to live directly inside the configured directory.
    assert result.sqlite_storage_path == str(tmp_path / "deerflow.db")
def test_database_memory_config_is_not_a_storage_backend():
    """The ``memory`` backend must be rejected with an "unsupported backend" error."""
    memory_config = SimpleNamespace(backend="memory")
    expect_rejection = pytest.raises(ValueError, match="Unsupported database backend")
    with expect_rejection:
        storage_config_from_database_config(memory_config)
def test_database_postgres_config_preserves_url_and_pool_options():
    """A postgres config keeps its URL, exposes the parsed URL parts, and carries pool options."""
    dsn = "postgresql://user:pass@db.example:5544/deerflow"
    db_config = SimpleNamespace(backend="postgres", postgres_url=dsn, echo_sql=True, pool_size=11)

    storage = storage_config_from_database_config(db_config)
    engine_url = _create_database_url(storage)

    # The original URL is preserved verbatim alongside its individual components.
    assert storage.driver == "postgres"
    assert storage.database_url == dsn
    assert storage.username == "user"
    assert storage.password == "pass"
    assert storage.host == "db.example"
    assert storage.port == 5544
    assert storage.db_name == "deerflow"

    # Engine options are passed through unchanged.
    assert storage.echo_sql is True
    assert storage.pool_size == 11

    # The URL built for the engine is expected to use the asyncpg driver name.
    assert engine_url.drivername == "postgresql+asyncpg"
    assert engine_url.database == "deerflow"
def test_database_postgres_requires_url():
    """An empty ``postgres_url`` on a postgres config must raise a ``ValueError``."""
    config_without_url = SimpleNamespace(backend="postgres", postgres_url="")
    expect_rejection = pytest.raises(ValueError, match="database.postgres_url is required")
    with expect_rejection:
        storage_config_from_database_config(config_without_url)
def test_unsupported_database_backend_rejected():
    """An unknown backend name (here ``oracle``) must raise a ``ValueError``."""
    unknown_config = SimpleNamespace(backend="oracle")
    expect_rejection = pytest.raises(ValueError, match="Unsupported database backend")
    with expect_rejection:
        storage_config_from_database_config(unknown_config)
@@ -0,0 +1,58 @@
from __future__ import annotations
import asyncio
import os
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from sqlalchemy import inspect
from store.persistence import create_persistence_from_database_config
from store.repositories import UserCreate, build_user_repository
def test_sqlite_persistence_from_database_config_creates_storage_tables(tmp_path):
    """Setting up sqlite persistence creates the storage tables and supports a user round trip."""

    def _table_names(sync_conn):
        # Schema inspection goes through the synchronous connection handed over by run_sync().
        return set(inspect(sync_conn).get_table_names())

    async def scenario() -> None:
        db_config = SimpleNamespace(
            backend="sqlite",
            sqlite_dir=str(tmp_path),
            echo_sql=False,
            pool_size=5,
        )
        persistence = await create_persistence_from_database_config(db_config)
        assert persistence is not None
        try:
            await persistence.setup()

            async with persistence.engine.connect() as conn:
                tables = await conn.run_sync(_table_names)
            required_tables = {"users", "runs", "run_events", "threads_meta", "feedback"}
            assert required_tables <= tables

            # Write a user in one session ...
            async with persistence.session_factory() as session:
                users = build_user_repository(session)
                new_user = UserCreate(
                    id=str(uuid4()),
                    email="storage-user@example.com",
                    password_hash="hash",
                )
                user = await users.create_user(new_user)
                await session.commit()

            # ... and read it back from a fresh one.
            async with persistence.session_factory() as session:
                users = build_user_repository(session)
                assert await users.get_user_by_id(user.id) == user
        finally:
            await persistence.aclose()

    asyncio.run(scenario())
+312
View File
@@ -0,0 +1,312 @@
from __future__ import annotations
import os
from datetime import UTC, datetime, timedelta
from pathlib import Path
from types import SimpleNamespace
import pytest
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from store.persistence import create_persistence_from_database_config
from store.repositories import (
FeedbackCreate,
RunCreate,
RunEventCreate,
ThreadMetaCreate,
build_feedback_repository,
build_run_event_repository,
build_run_repository,
build_thread_meta_repository,
)
async def _make_persistence(tmp_path):
    """Create and set up a sqlite-backed persistence object rooted at ``tmp_path``.

    Args:
        tmp_path: Directory passed (as a string) as ``sqlite_dir`` of the database config.

    Returns:
        The persistence object after ``setup()`` has completed. The caller is
        responsible for calling ``aclose()`` on it.

    Raises:
        Whatever ``setup()`` raises; in that case the persistence object is
        closed before the error propagates.
    """
    persistence = await create_persistence_from_database_config(
        SimpleNamespace(
            backend="sqlite",
            sqlite_dir=str(tmp_path),
            echo_sql=False,
            pool_size=5,
        )
    )
    try:
        await persistence.setup()
    except BaseException:
        # Callers only enter their try/finally after this helper returns, so a
        # failed setup would otherwise leave the persistence object unclosed.
        await persistence.aclose()
        raise
    return persistence
@pytest.mark.anyio
async def test_storage_run_repository_filters_and_aggregates(tmp_path):
    """Cover run creation, completion updates, per-thread listing filters and token aggregation."""
    persistence = await _make_persistence(tmp_path)
    earlier = datetime.now(UTC) - timedelta(hours=1)
    later = datetime.now(UTC)
    try:
        # Phase 1: create three runs (two on thread-1) and complete the two thread-1 runs.
        async with persistence.session_factory() as session:
            runs = build_run_repository(session)

            old_run = RunCreate(
                run_id="run-old",
                thread_id="thread-1",
                user_id="alice",
                status="pending",
                model_name="model-a",
                metadata={"kind": "draft"},
                kwargs={"temperature": 0.2},
                created_time=earlier,
            )
            await runs.create_run(old_run)

            new_run = RunCreate(
                run_id="run-new",
                thread_id="thread-1",
                user_id="bob",
                status="running",
                model_name="model-b",
                error="queued",
                created_time=later,
            )
            await runs.create_run(new_run)

            other_run = RunCreate(run_id="run-other", thread_id="thread-2", status="running")
            await runs.create_run(other_run)

            old_completion = {
                "status": "success",
                "total_input_tokens": 7,
                "total_output_tokens": 3,
                "total_tokens": 10,
                "llm_call_count": 1,
                "lead_agent_tokens": 8,
                "subagent_tokens": 2,
                "first_human_message": "hello",
                "last_ai_message": "world",
            }
            await runs.update_run_completion("run-old", **old_completion)

            new_completion = {
                "status": "error",
                "total_tokens": 5,
                "middleware_tokens": 5,
                "error": "failed",
            }
            await runs.update_run_completion("run-new", **new_completion)
            await session.commit()

        # Phase 2: read everything back through a fresh session.
        async with persistence.session_factory() as session:
            runs = build_run_repository(session)

            stored = await runs.get_run("run-old")
            assert stored is not None
            assert stored.metadata == {"kind": "draft"}
            assert stored.kwargs == {"temperature": 0.2}
            assert stored.first_human_message == "hello"
            assert stored.last_ai_message == "world"

            # The run with the later created_time is expected first.
            thread_runs = await runs.list_runs_by_thread("thread-1")
            assert [item.run_id for item in thread_runs] == ["run-new", "run-old"]

            alice_only = await runs.list_runs_by_thread("thread-1", user_id="alice")
            assert [item.run_id for item in alice_only] == ["run-old"]

            # No run is expected to be reported as pending at this point.
            cutoff = datetime.now(UTC).isoformat()
            still_pending = await runs.list_pending(before=cutoff)
            assert [item.run_id for item in still_pending] == []

            usage = await runs.aggregate_tokens_by_thread("thread-1")
            assert usage["total_tokens"] == 15
            assert usage["total_input_tokens"] == 7
            assert usage["total_output_tokens"] == 3
            assert usage["total_runs"] == 2
            expected_by_model = {
                "model-a": {"tokens": 10, "runs": 1},
                "model-b": {"tokens": 5, "runs": 1},
            }
            assert usage["by_model"] == expected_by_model
            expected_by_caller = {"lead_agent": 8, "subagent": 2, "middleware": 5}
            assert usage["by_caller"] == expected_by_caller
    finally:
        await persistence.aclose()
@pytest.mark.anyio
async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
    """Cover thread-meta creation, update, search by metadata/assistant, and deletion."""
    persistence = await _make_persistence(tmp_path)
    try:
        # Phase 1: create two threads and update the first one.
        async with persistence.session_factory() as session:
            threads = build_thread_meta_repository(session)

            first_thread = ThreadMetaCreate(
                thread_id="thread-1",
                assistant_id="agent-a",
                user_id="alice",
                display_name="Initial",
                status="idle",
                metadata={"topic": "finance", "region": "cn"},
            )
            await threads.create_thread_meta(first_thread)

            second_thread = ThreadMetaCreate(
                thread_id="thread-2",
                assistant_id="agent-b",
                user_id="bob",
                status="running",
                metadata={"topic": "legal"},
            )
            await threads.create_thread_meta(second_thread)

            changes = {
                "display_name": "Updated",
                "status": "running",
                "metadata": {"topic": "finance", "region": "us"},
            }
            await threads.update_thread_meta("thread-1", **changes)
            await session.commit()

        # Phase 2: verify the update, run the searches, then delete thread-1.
        async with persistence.session_factory() as session:
            threads = build_thread_meta_repository(session)

            stored = await threads.get_thread_meta("thread-1")
            assert stored is not None
            assert stored.display_name == "Updated"
            assert stored.status == "running"
            assert stored.metadata == {"topic": "finance", "region": "us"}

            finance_for_alice = await threads.search_threads(metadata={"topic": "finance"}, user_id="alice")
            assert [item.thread_id for item in finance_for_alice] == ["thread-1"]

            agent_b_threads = await threads.search_threads(assistant_id="agent-b")
            assert [item.thread_id for item in agent_b_threads] == ["thread-2"]

            await threads.delete_thread("thread-1")
            await session.commit()

        # Phase 3: the deleted thread is no longer retrievable.
        async with persistence.session_factory() as session:
            threads = build_thread_meta_repository(session)
            assert await threads.get_thread_meta("thread-1") is None
    finally:
        await persistence.aclose()
@pytest.mark.anyio
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
    """Cover feedback creation, lookup, listing by run/thread, deletion and rating validation."""
    persistence = await _make_persistence(tmp_path)
    try:
        # Phase 1: store one positive and one negative feedback for the same run.
        async with persistence.session_factory() as session:
            feedback = build_feedback_repository(session)

            positive = FeedbackCreate(
                feedback_id="fb-1",
                run_id="run-1",
                thread_id="thread-1",
                rating=1,
                user_id="alice",
                message_id="msg-1",
                comment="good",
            )
            first = await feedback.create_feedback(positive)

            negative = FeedbackCreate(
                feedback_id="fb-2",
                run_id="run-1",
                thread_id="thread-1",
                rating=-1,
                user_id="bob",
            )
            second = await feedback.create_feedback(negative)
            await session.commit()

        # Phase 2: read back, list, delete, and check that a zero rating is refused.
        async with persistence.session_factory() as session:
            feedback = build_feedback_repository(session)

            assert await feedback.get_feedback(first.feedback_id) == first

            # The per-run listing is expected to return the later feedback first.
            run_ids = [item.feedback_id for item in await feedback.list_feedback_by_run("run-1")]
            assert run_ids == [second.feedback_id, first.feedback_id]

            # The per-thread listing is only checked for membership, not order.
            thread_ids = {item.feedback_id for item in await feedback.list_feedback_by_thread("thread-1")}
            assert thread_ids == {"fb-1", "fb-2"}

            assert await feedback.delete_feedback("fb-1") is True
            assert await feedback.delete_feedback("missing") is False

            with pytest.raises(ValueError, match="rating must be"):
                invalid = FeedbackCreate(
                    feedback_id="fb-bad",
                    run_id="run-1",
                    thread_id="thread-1",
                    rating=0,
                )
                await feedback.create_feedback(invalid)
            await session.commit()

        # Phase 3: the deletion is visible from a fresh session.
        async with persistence.session_factory() as session:
            feedback = build_feedback_repository(session)
            assert await feedback.get_feedback("fb-1") is None
    finally:
        await persistence.aclose()
@pytest.mark.anyio
async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_path):
    """Cover batch append sequencing, message pagination, event filtering and deletion."""
    persistence = await _make_persistence(tmp_path)
    try:
        # Phase 1: append four events in one batch (three on thread-1, one on thread-2).
        async with persistence.session_factory() as session:
            events_repo = build_run_event_repository(session)
            batch = [
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-1",
                    user_id="alice",
                    event_type="message",
                    category="message",
                    content={"role": "user", "content": "hello"},
                    metadata={"source": "input"},
                ),
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-1",
                    event_type="tool",
                    category="debug",
                    content="tool-call",
                ),
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-2",
                    event_type="message",
                    category="message",
                    content="second",
                ),
                RunEventCreate(
                    thread_id="thread-2",
                    run_id="run-3",
                    event_type="message",
                    category="message",
                    content="other-thread",
                ),
            ]
            rows = await events_repo.append_batch(batch)
            await session.commit()

            # Sequence numbers are expected to count up independently per thread.
            sequencing = [(row.thread_id, row.seq) for row in rows]
            assert sequencing == [
                ("thread-1", 1),
                ("thread-1", 2),
                ("thread-1", 3),
                ("thread-2", 1),
            ]
            head = rows[0]
            assert head.content == {"role": "user", "content": "hello"}
            # Dict content is expected to come back flagged as JSON in the metadata.
            assert head.metadata == {"source": "input", "content_is_json": True}

        # Phase 2: query through a fresh session, then delete by run and by thread.
        async with persistence.session_factory() as session:
            events_repo = build_run_event_repository(session)

            first_page = await events_repo.list_messages("thread-1", limit=2)
            assert [item.seq for item in first_page] == [1, 3]
            assert await events_repo.count_messages("thread-1") == 2

            run_messages = await events_repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
            assert [item.seq for item in run_messages] == [1]

            earlier_messages = await events_repo.list_messages("thread-1", before_seq=3, limit=5)
            assert [item.seq for item in earlier_messages] == [1]

            tool_events = await events_repo.list_events("thread-1", "run-1", event_types=["tool"])
            assert [item.content for item in tool_events] == ["tool-call"]

            assert await events_repo.delete_by_run("thread-1", "run-1") == 2
            assert await events_repo.delete_by_thread("thread-2") == 1
            await session.commit()

        # Phase 3: only the run-2 event on thread-1 should remain.
        async with persistence.session_factory() as session:
            events_repo = build_run_event_repository(session)
            leftovers = await events_repo.list_events("thread-1", "run-2")
            assert [item.seq for item in leftovers] == [3]
            assert await events_repo.count_messages("thread-2") == 0
    finally:
        await persistence.aclose()
@@ -0,0 +1,178 @@
from __future__ import annotations
import asyncio
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from uuid import uuid4
import pytest
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
import store.repositories.models # noqa: F401
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.persistence import MappedBase
from store.repositories import UserCreate, UserNotFoundError, build_user_repository
@asynccontextmanager
async def _session_factory(tmp_path) -> AsyncGenerator[async_sessionmaker[AsyncSession]]:
    """Yield a session factory bound to a fresh sqlite database under ``tmp_path``.

    The database file is ``storage-users.db``. All tables registered on
    ``MappedBase.metadata`` are created before the factory is yielded.

    Args:
        tmp_path: Directory in which the sqlite database file is created.

    Yields:
        An ``async_sessionmaker`` created with ``expire_on_commit=False``.
    """
    db_path = tmp_path / "storage-users.db"
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
    try:
        # Schema creation sits inside the try so that a failing create_all
        # still disposes the engine instead of leaking it.
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)
        yield async_sessionmaker(engine, expire_on_commit=False)
    finally:
        await engine.dispose()
async def _create_user(
    session_factory: async_sessionmaker[AsyncSession],
    *,
    email: str = "user@example.com",
    system_role: str = "user",
    oauth_provider: str | None = None,
    oauth_id: str | None = None,
):
    """Create a user with a random id in its own session, commit, and return it.

    Args:
        session_factory: Factory used to open the session for the insert.
        email: Email address of the new user.
        system_role: Role stored on the user.
        oauth_provider: Optional OAuth provider name.
        oauth_id: Optional provider-side account id.

    Returns:
        The value returned by the repository's ``create_user``.
    """
    payload = UserCreate(
        id=str(uuid4()),
        email=email,
        password_hash="hash",
        system_role=system_role,  # type: ignore[arg-type]
        oauth_provider=oauth_provider,
        oauth_id=oauth_id,
    )
    async with session_factory() as session:
        users = build_user_repository(session)
        created = await users.create_user(payload)
        await session.commit()
    return created
def test_create_and_get_user_by_id_and_email(tmp_path):
    """A created user can be fetched by id and by email and carries the expected default fields."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            created = await _create_user(session_factory)

            async with session_factory() as session:
                users = build_user_repository(session)
                found_by_id = await users.get_user_by_id(created.id)
                found_by_email = await users.get_user_by_email(created.email)

            # Both lookups must return the same record that was created.
            assert found_by_id == created
            assert found_by_email == created

            # Field values expected on a user created through the helper's defaults.
            assert created.system_role == "user"
            assert created.needs_setup is False
            assert created.token_version == 0

    asyncio.run(scenario())
def test_duplicate_email_raises_value_error(tmp_path):
    """Creating a second user with an already-used email raises ``ValueError``."""
    shared_email = "dupe@example.com"

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email=shared_email)

            async with session_factory() as session:
                users = build_user_repository(session)
                duplicate = UserCreate(
                    id=str(uuid4()),
                    email=shared_email,
                    password_hash="hash",
                )
                with pytest.raises(ValueError, match="Email already registered"):
                    await users.create_user(duplicate)

    asyncio.run(scenario())
def test_oauth_lookup_and_plain_users_without_oauth(tmp_path):
    """Several users without OAuth data coexist, and OAuth lookup finds only the matching user."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            # Two users with no OAuth identity at all.
            for local_email in ("local-1@example.com", "local-2@example.com"):
                await _create_user(session_factory, email=local_email)

            github_user = await _create_user(
                session_factory,
                email="oauth@example.com",
                oauth_provider="github",
                oauth_id="gh-123",
            )

            async with session_factory() as session:
                users = build_user_repository(session)
                assert await users.count_users() == 3
                assert await users.get_user_by_oauth("github", "gh-123") == github_user
                assert await users.get_user_by_oauth("github", "missing") is None

    asyncio.run(scenario())
def test_count_admins_and_get_first_admin(tmp_path):
    """With one regular user and one admin, the admin count is 1 and the admin is the first admin."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email="user@example.com")
            admin_user = await _create_user(
                session_factory,
                email="admin@example.com",
                system_role="admin",
            )

            async with session_factory() as session:
                users = build_user_repository(session)
                assert await users.count_users() == 2
                assert await users.count_admin_users() == 1
                assert await users.get_first_admin() == admin_user

    asyncio.run(scenario())
def test_update_user_round_trips_token_version_and_setup_state(tmp_path):
    """Updated email, token version and setup flag are persisted and read back unchanged."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            original = await _create_user(session_factory)
            changes = {
                "email": "renamed@example.com",
                "token_version": 4,
                "needs_setup": True,
            }
            modified = original.model_copy(update=changes)

            # Persist the modified copy ...
            async with session_factory() as session:
                users = build_user_repository(session)
                saved = await users.update_user(modified)
                await session.commit()

            # ... and load it again in a separate session.
            async with session_factory() as session:
                users = build_user_repository(session)
                reloaded = await users.get_user_by_id(original.id)

            assert saved.email == "renamed@example.com"
            assert reloaded == modified

    asyncio.run(scenario())
def test_update_missing_user_raises(tmp_path):
    """Updating a user whose insert was rolled back raises ``UserNotFoundError``."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            payload = UserCreate(id=str(uuid4()), email="missing@example.com")

            async with session_factory() as session:
                users = build_user_repository(session)
                # Create the user only to get a correctly shaped object, then
                # roll back so the update below targets a row that was never committed.
                never_persisted = await users.create_user(payload)
                await session.rollback()

                with pytest.raises(UserNotFoundError):
                    await users.update_user(never_persisted)

    asyncio.run(scenario())