mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
feat(storage): add storage package base
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from store.config.storage_config import StorageConfig
|
||||
from store.persistence.factory import _create_database_url, storage_config_from_database_config
|
||||
|
||||
|
||||
def test_database_sqlite_config_maps_to_storage_config(tmp_path):
    """A SQLite database section is translated into an equal ``StorageConfig``.

    Also checks that the derived ``sqlite_storage_path`` points at
    ``deerflow.db`` inside the configured directory.
    """
    sqlite_dir = str(tmp_path)
    db_section = SimpleNamespace(
        backend="sqlite",
        sqlite_dir=sqlite_dir,
        echo_sql=True,
        pool_size=9,
    )

    mapped = storage_config_from_database_config(db_section)

    expected = StorageConfig(
        driver="sqlite",
        sqlite_dir=sqlite_dir,
        echo_sql=True,
        pool_size=9,
    )
    assert mapped == expected
    assert mapped.sqlite_storage_path == str(tmp_path / "deerflow.db")
|
||||
|
||||
|
||||
def test_database_memory_config_is_not_a_storage_backend():
    """The ``memory`` backend is expected to be rejected with ``ValueError``."""
    memory_section = SimpleNamespace(backend="memory")
    expect_rejection = pytest.raises(ValueError, match="Unsupported database backend")

    with expect_rejection:
        storage_config_from_database_config(memory_section)
|
||||
|
||||
|
||||
def test_database_postgres_config_preserves_url_and_pool_options():
    """A Postgres section keeps its URL, parsed URL parts and pool options.

    The resulting storage config is also fed through ``_create_database_url``
    to check the driver name and database of the engine URL.
    """
    postgres_url = "postgresql://user:pass@db.example:5544/deerflow"
    db_section = SimpleNamespace(
        backend="postgres",
        postgres_url=postgres_url,
        echo_sql=True,
        pool_size=11,
    )

    mapped = storage_config_from_database_config(db_section)
    engine_url = _create_database_url(mapped)

    # Fields compared by equality, in the order they are checked.
    expected_fields = {
        "driver": "postgres",
        "database_url": postgres_url,
        "username": "user",
        "password": "pass",
        "host": "db.example",
        "port": 5544,
        "db_name": "deerflow",
    }
    for field_name, expected in expected_fields.items():
        assert getattr(mapped, field_name) == expected

    # echo_sql must be the literal True, not merely truthy.
    assert mapped.echo_sql is True
    assert mapped.pool_size == 11

    assert engine_url.drivername == "postgresql+asyncpg"
    assert engine_url.database == "deerflow"
|
||||
|
||||
|
||||
def test_database_postgres_requires_url():
    """A Postgres section with an empty ``postgres_url`` must raise ``ValueError``."""
    urlless_section = SimpleNamespace(backend="postgres", postgres_url="")
    expect_rejection = pytest.raises(ValueError, match="database.postgres_url is required")

    with expect_rejection:
        storage_config_from_database_config(urlless_section)
|
||||
|
||||
|
||||
def test_unsupported_database_backend_rejected():
    """An unknown backend name (here ``oracle``) must raise ``ValueError``."""
    oracle_section = SimpleNamespace(backend="oracle")
    expect_rejection = pytest.raises(ValueError, match="Unsupported database backend")

    with expect_rejection:
        storage_config_from_database_config(oracle_section)
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from sqlalchemy import inspect
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.repositories import UserCreate, build_user_repository
|
||||
|
||||
|
||||
def test_sqlite_persistence_from_database_config_creates_storage_tables(tmp_path):
    """SQLite persistence built from a database section creates the storage tables.

    After ``setup()`` the expected tables must exist, and a user written
    through one session must be readable from a fresh session.
    """
    required_tables = {"users", "runs", "run_events", "threads_meta", "feedback"}

    def _table_names(sync_conn):
        # Runs on the sync side of the async connection via ``run_sync``.
        return set(inspect(sync_conn).get_table_names())

    async def scenario() -> None:
        db_section = SimpleNamespace(
            backend="sqlite",
            sqlite_dir=str(tmp_path),
            echo_sql=False,
            pool_size=5,
        )
        persistence = await create_persistence_from_database_config(db_section)
        assert persistence is not None
        try:
            await persistence.setup()

            async with persistence.engine.connect() as conn:
                existing = await conn.run_sync(_table_names)
            assert required_tables.issubset(existing)

            async with persistence.session_factory() as session:
                users = build_user_repository(session)
                payload = UserCreate(
                    id=str(uuid4()),
                    email="storage-user@example.com",
                    password_hash="hash",
                )
                stored = await users.create_user(payload)
                await session.commit()

            # Read back through a separate session to prove the commit landed.
            async with persistence.session_factory() as session:
                users = build_user_repository(session)
                reloaded = await users.get_user_by_id(stored.id)
                assert reloaded == stored
        finally:
            await persistence.aclose()

    asyncio.run(scenario())
|
||||
@@ -0,0 +1,312 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.repositories import (
|
||||
FeedbackCreate,
|
||||
RunCreate,
|
||||
RunEventCreate,
|
||||
ThreadMetaCreate,
|
||||
build_feedback_repository,
|
||||
build_run_event_repository,
|
||||
build_run_repository,
|
||||
build_thread_meta_repository,
|
||||
)
|
||||
|
||||
|
||||
async def _make_persistence(tmp_path):
    """Create and set up a SQLite-backed persistence object under ``tmp_path``.

    The caller owns the returned object and must ``await persistence.aclose()``
    when done. If ``setup()`` itself fails, the object is closed here before
    the error propagates, because the caller never receives a handle to clean
    up.
    """
    persistence = await create_persistence_from_database_config(
        SimpleNamespace(
            backend="sqlite",
            sqlite_dir=str(tmp_path),
            echo_sql=False,
            pool_size=5,
        )
    )
    try:
        await persistence.setup()
    except BaseException:
        # Callers only enter their try/finally after this helper returns, so
        # release the resources here and re-raise the original error.
        await persistence.aclose()
        raise
    return persistence
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_storage_run_repository_filters_and_aggregates(tmp_path):
    """Exercise run creation, completion updates, listing and token aggregation.

    Three runs are seeded (two on ``thread-1``, one on ``thread-2``), two are
    completed, and the test then checks fetch, per-thread listing with and
    without a user filter, the pending listing, and the aggregated token
    statistics for ``thread-1``.
    """
    persistence = await _make_persistence(tmp_path)
    an_hour_ago = datetime.now(UTC) - timedelta(hours=1)
    just_now = datetime.now(UTC)
    try:
        async with persistence.session_factory() as session:
            runs = build_run_repository(session)

            seed_runs = (
                RunCreate(
                    run_id="run-old",
                    thread_id="thread-1",
                    user_id="alice",
                    status="pending",
                    model_name="model-a",
                    metadata={"kind": "draft"},
                    kwargs={"temperature": 0.2},
                    created_time=an_hour_ago,
                ),
                RunCreate(
                    run_id="run-new",
                    thread_id="thread-1",
                    user_id="bob",
                    status="running",
                    model_name="model-b",
                    error="queued",
                    created_time=just_now,
                ),
                RunCreate(run_id="run-other", thread_id="thread-2", status="running"),
            )
            for seed in seed_runs:
                await runs.create_run(seed)

            await runs.update_run_completion(
                "run-old",
                status="success",
                total_input_tokens=7,
                total_output_tokens=3,
                total_tokens=10,
                llm_call_count=1,
                lead_agent_tokens=8,
                subagent_tokens=2,
                first_human_message="hello",
                last_ai_message="world",
            )
            await runs.update_run_completion(
                "run-new",
                status="error",
                total_tokens=5,
                middleware_tokens=5,
                error="failed",
            )
            await session.commit()

        async with persistence.session_factory() as session:
            runs = build_run_repository(session)

            old_run = await runs.get_run("run-old")
            assert old_run is not None
            assert old_run.metadata == {"kind": "draft"}
            assert old_run.kwargs == {"temperature": 0.2}
            assert old_run.first_human_message == "hello"
            assert old_run.last_ai_message == "world"

            # The expected order puts the more recently created run first.
            thread_runs = await runs.list_runs_by_thread("thread-1")
            assert [item.run_id for item in thread_runs] == ["run-new", "run-old"]
            alice_only = await runs.list_runs_by_thread("thread-1", user_id="alice")
            assert [item.run_id for item in alice_only] == ["run-old"]

            # No seeded run is expected in the pending listing after the updates.
            cutoff = datetime.now(UTC).isoformat()
            still_pending = await runs.list_pending(before=cutoff)
            assert [item.run_id for item in still_pending] == []

            stats = await runs.aggregate_tokens_by_thread("thread-1")
            assert stats["total_tokens"] == 15
            assert stats["total_input_tokens"] == 7
            assert stats["total_output_tokens"] == 3
            assert stats["total_runs"] == 2
            expected_by_model = {
                "model-a": {"tokens": 10, "runs": 1},
                "model-b": {"tokens": 5, "runs": 1},
            }
            assert stats["by_model"] == expected_by_model
            assert stats["by_caller"] == {"lead_agent": 8, "subagent": 2, "middleware": 5}
    finally:
        await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
    """Exercise thread-meta create, update, search and delete.

    Two threads are created, the first is updated, and the test then checks
    the updated values, search by metadata + user and by assistant, and that
    a deleted thread can no longer be fetched.
    """
    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            threads = build_thread_meta_repository(session)

            alice_thread = ThreadMetaCreate(
                thread_id="thread-1",
                assistant_id="agent-a",
                user_id="alice",
                display_name="Initial",
                status="idle",
                metadata={"topic": "finance", "region": "cn"},
            )
            await threads.create_thread_meta(alice_thread)

            bob_thread = ThreadMetaCreate(
                thread_id="thread-2",
                assistant_id="agent-b",
                user_id="bob",
                status="running",
                metadata={"topic": "legal"},
            )
            await threads.create_thread_meta(bob_thread)

            await threads.update_thread_meta(
                "thread-1",
                display_name="Updated",
                status="running",
                metadata={"topic": "finance", "region": "us"},
            )
            await session.commit()

        async with persistence.session_factory() as session:
            threads = build_thread_meta_repository(session)

            current = await threads.get_thread_meta("thread-1")
            assert current is not None
            assert current.display_name == "Updated"
            assert current.status == "running"
            assert current.metadata == {"topic": "finance", "region": "us"}

            finance_hits = await threads.search_threads(metadata={"topic": "finance"}, user_id="alice")
            assert [hit.thread_id for hit in finance_hits] == ["thread-1"]
            agent_b_hits = await threads.search_threads(assistant_id="agent-b")
            assert [hit.thread_id for hit in agent_b_hits] == ["thread-2"]

            await threads.delete_thread("thread-1")
            await session.commit()

        # Verify the deletion from a fresh session.
        async with persistence.session_factory() as session:
            threads = build_thread_meta_repository(session)
            gone = await threads.get_thread_meta("thread-1")
            assert gone is None
    finally:
        await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
    """Exercise feedback create, fetch, listing, deletion and rating validation.

    Two feedback rows are stored for the same run. The test then checks fetch
    by id, listing by run (expected order: ``fb-2`` then ``fb-1``), listing by
    thread, delete results for an existing and a missing id, and that a
    rating of ``0`` is rejected with ``ValueError``.
    """
    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            feedback = build_feedback_repository(session)

            positive_payload = FeedbackCreate(
                feedback_id="fb-1",
                run_id="run-1",
                thread_id="thread-1",
                rating=1,
                user_id="alice",
                message_id="msg-1",
                comment="good",
            )
            positive = await feedback.create_feedback(positive_payload)

            negative_payload = FeedbackCreate(
                feedback_id="fb-2",
                run_id="run-1",
                thread_id="thread-1",
                rating=-1,
                user_id="bob",
            )
            negative = await feedback.create_feedback(negative_payload)
            await session.commit()

        async with persistence.session_factory() as session:
            feedback = build_feedback_repository(session)

            fetched = await feedback.get_feedback(positive.feedback_id)
            assert fetched == positive

            for_run = await feedback.list_feedback_by_run("run-1")
            assert [entry.feedback_id for entry in for_run] == [
                negative.feedback_id,
                positive.feedback_id,
            ]

            for_thread = await feedback.list_feedback_by_thread("thread-1")
            assert {entry.feedback_id for entry in for_thread} == {"fb-1", "fb-2"}

            assert await feedback.delete_feedback("fb-1") is True
            assert await feedback.delete_feedback("missing") is False

            invalid_payload_error = pytest.raises(ValueError, match="rating must be")
            with invalid_payload_error:
                await feedback.create_feedback(
                    FeedbackCreate(
                        feedback_id="fb-bad",
                        run_id="run-1",
                        thread_id="thread-1",
                        rating=0,
                    )
                )
            await session.commit()

        # Verify the deletion from a fresh session.
        async with persistence.session_factory() as session:
            feedback = build_feedback_repository(session)
            assert await feedback.get_feedback("fb-1") is None
    finally:
        await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_path):
    """Exercise run-event batch append, sequencing, listing and deletion.

    Four events are appended in one batch (three on ``thread-1``, one on
    ``thread-2``). The test checks the per-thread ``seq`` values of the
    returned rows, message listing and counting, ``after_seq`` / ``before_seq``
    filtering, event-type filtering, and the counts returned by the delete
    calls.
    """

    def seqs(events):
        return [event.seq for event in events]

    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            events_repo = build_run_event_repository(session)
            batch = [
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-1",
                    user_id="alice",
                    event_type="message",
                    category="message",
                    content={"role": "user", "content": "hello"},
                    metadata={"source": "input"},
                ),
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-1",
                    event_type="tool",
                    category="debug",
                    content="tool-call",
                ),
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-2",
                    event_type="message",
                    category="message",
                    content="second",
                ),
                RunEventCreate(
                    thread_id="thread-2",
                    run_id="run-3",
                    event_type="message",
                    category="message",
                    content="other-thread",
                ),
            ]
            appended = await events_repo.append_batch(batch)
            await session.commit()

        # Sequence numbers are expected to restart for each thread.
        assigned = [(row.thread_id, row.seq) for row in appended]
        assert assigned == [
            ("thread-1", 1),
            ("thread-1", 2),
            ("thread-1", 3),
            ("thread-2", 1),
        ]
        head = appended[0]
        assert head.content == {"role": "user", "content": "hello"}
        assert head.metadata == {"source": "input", "content_is_json": True}

        async with persistence.session_factory() as session:
            events_repo = build_run_event_repository(session)

            thread_messages = await events_repo.list_messages("thread-1", limit=2)
            assert seqs(thread_messages) == [1, 3]
            assert await events_repo.count_messages("thread-1") == 2

            run_messages = await events_repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
            assert seqs(run_messages) == [1]
            earlier_messages = await events_repo.list_messages("thread-1", before_seq=3, limit=5)
            assert seqs(earlier_messages) == [1]

            tool_events = await events_repo.list_events("thread-1", "run-1", event_types=["tool"])
            assert [event.content for event in tool_events] == ["tool-call"]

            assert await events_repo.delete_by_run("thread-1", "run-1") == 2
            assert await events_repo.delete_by_thread("thread-2") == 1
            await session.commit()

        # Only the run-2 event on thread-1 should remain after the deletions.
        async with persistence.session_factory() as session:
            events_repo = build_run_event_repository(session)
            leftovers = await events_repo.list_events("thread-1", "run-2")
            assert seqs(leftovers) == [3]
            assert await events_repo.count_messages("thread-2") == 0
    finally:
        await persistence.aclose()
|
||||
@@ -0,0 +1,178 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from store.persistence import MappedBase
|
||||
from store.repositories import UserCreate, UserNotFoundError, build_user_repository
|
||||
|
||||
|
||||
@asynccontextmanager
async def _session_factory(tmp_path) -> AsyncGenerator[async_sessionmaker[AsyncSession]]:
    """Yield a session factory bound to a fresh SQLite database in ``tmp_path``.

    The schema is created from ``MappedBase.metadata`` before the factory is
    yielded. The engine is disposed on exit, including when schema creation
    itself fails, so a failing setup does not leave the engine open.
    """
    db_path = tmp_path / "storage-users.db"
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
    try:
        # Schema creation sits inside the try so that dispose() always runs.
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

        yield async_sessionmaker(engine, expire_on_commit=False)
    finally:
        await engine.dispose()
|
||||
|
||||
|
||||
async def _create_user(
    session_factory: async_sessionmaker[AsyncSession],
    *,
    email: str = "user@example.com",
    system_role: str = "user",
    oauth_provider: str | None = None,
    oauth_id: str | None = None,
):
    """Create and commit one user in its own session, then return it.

    The user gets a random UUID id and the fixed password hash ``"hash"``;
    all other fields come from the keyword arguments.
    """
    payload = UserCreate(
        id=str(uuid4()),
        email=email,
        password_hash="hash",
        system_role=system_role,  # type: ignore[arg-type]
        oauth_provider=oauth_provider,
        oauth_id=oauth_id,
    )
    async with session_factory() as session:
        users = build_user_repository(session)
        created = await users.create_user(payload)
        await session.commit()
        return created
|
||||
|
||||
|
||||
def test_create_and_get_user_by_id_and_email(tmp_path):
    """A created user can be fetched by id and by email, with the expected defaults."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            created = await _create_user(session_factory)

            async with session_factory() as session:
                users = build_user_repository(session)
                found_by_id = await users.get_user_by_id(created.id)
                found_by_email = await users.get_user_by_email(created.email)

            assert found_by_id == created
            assert found_by_email == created
            # Values expected on a user created with the helper's defaults.
            assert created.system_role == "user"
            assert created.needs_setup is False
            assert created.token_version == 0

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_duplicate_email_raises_value_error(tmp_path):
    """Creating a second user with an already-used email must raise ``ValueError``."""

    async def scenario() -> None:
        taken_email = "dupe@example.com"
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email=taken_email)

            async with session_factory() as session:
                users = build_user_repository(session)
                duplicate = UserCreate(
                    id=str(uuid4()),
                    email=taken_email,
                    password_hash="hash",
                )
                with pytest.raises(ValueError, match="Email already registered"):
                    await users.create_user(duplicate)

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_oauth_lookup_and_plain_users_without_oauth(tmp_path):
    """OAuth lookup finds the linked user while users without OAuth coexist.

    Two users without OAuth fields and one GitHub-linked user are created;
    the lookup must return the linked user for the matching id and ``None``
    for an unknown one.
    """

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            for local_email in ("local-1@example.com", "local-2@example.com"):
                await _create_user(session_factory, email=local_email)
            github_user = await _create_user(
                session_factory,
                email="oauth@example.com",
                oauth_provider="github",
                oauth_id="gh-123",
            )

            async with session_factory() as session:
                users = build_user_repository(session)

                total = await users.count_users()
                assert total == 3
                matched = await users.get_user_by_oauth("github", "gh-123")
                assert matched == github_user
                unmatched = await users.get_user_by_oauth("github", "missing")
                assert unmatched is None

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_count_admins_and_get_first_admin(tmp_path):
    """With one regular user and one admin, the counts and first-admin lookup agree."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email="user@example.com")
            the_admin = await _create_user(
                session_factory,
                email="admin@example.com",
                system_role="admin",
            )

            async with session_factory() as session:
                users = build_user_repository(session)

                everyone = await users.count_users()
                assert everyone == 2
                admins = await users.count_admin_users()
                assert admins == 1
                first_admin = await users.get_first_admin()
                assert first_admin == the_admin

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_update_user_round_trips_token_version_and_setup_state(tmp_path):
    """An updated email, token version and setup flag survive a save and reload."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            original = await _create_user(session_factory)
            changes = {
                "email": "renamed@example.com",
                "token_version": 4,
                "needs_setup": True,
            }
            desired = original.model_copy(update=changes)

            async with session_factory() as session:
                users = build_user_repository(session)
                saved = await users.update_user(desired)
                await session.commit()

            # Reload through a new session so the comparison reflects stored data.
            async with session_factory() as session:
                users = build_user_repository(session)
                reloaded = await users.get_user_by_id(original.id)

            assert saved.email == "renamed@example.com"
            assert reloaded == desired

    asyncio.run(scenario())
|
||||
|
||||
|
||||
def test_update_missing_user_raises(tmp_path):
    """Updating a user whose creation was rolled back must raise ``UserNotFoundError``."""

    async def scenario() -> None:
        async with _session_factory(tmp_path) as session_factory:
            ghost_payload = UserCreate(id=str(uuid4()), email="missing@example.com")

            async with session_factory() as session:
                users = build_user_repository(session)
                # Create then roll back: this yields a user-shaped object
                # without leaving a committed row behind.
                ghost = await users.create_user(ghost_payload)
                await session.rollback()

                with pytest.raises(UserNotFoundError):
                    await users.update_user(ghost)

    asyncio.run(scenario())
|
||||
Reference in New Issue
Block a user