"""Round-trip tests for the storage repositories on a SQLite-backed persistence object."""

from __future__ import annotations

import os
from datetime import UTC, datetime, timedelta
from pathlib import Path
from types import SimpleNamespace

import pytest

# NOTE(review): this default is installed before the ``store`` imports below,
# presumably because configuration is resolved when ``store`` is imported --
# confirm before reordering these statements.
_EXAMPLE_CONFIG = Path(__file__).resolve().parents[2] / "config.example.yaml"
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(_EXAMPLE_CONFIG))

from store.persistence import create_persistence_from_database_config
from store.repositories import (
    FeedbackCreate,
    RunCreate,
    RunEventCreate,
    ThreadMetaCreate,
    build_feedback_repository,
    build_run_event_repository,
    build_run_repository,
    build_thread_meta_repository,
)


async def _make_persistence(tmp_path):
    """Create a SQLite persistence object under *tmp_path* and run its setup.

    The caller is responsible for ``await``-ing ``aclose()`` on the result.
    """
    database_config = SimpleNamespace(
        backend="sqlite",
        sqlite_dir=str(tmp_path),
        echo_sql=False,
        pool_size=5,
    )
    handle = await create_persistence_from_database_config(database_config)
    await handle.setup()
    return handle


@pytest.mark.anyio
async def test_storage_run_repository_filters_and_aggregates(tmp_path):
    """Runs can be read back, listed per thread/user, and aggregated by tokens."""
    store = await _make_persistence(tmp_path)
    hour_ago = datetime.now(UTC) - timedelta(hours=1)
    just_now = datetime.now(UTC)
    try:
        # Phase 1: create three runs, complete two of them, and commit.
        async with store.session_factory() as db:
            runs = build_run_repository(db)

            older_run = RunCreate(
                run_id="run-old",
                thread_id="thread-1",
                user_id="alice",
                status="pending",
                model_name="model-a",
                metadata={"kind": "draft"},
                kwargs={"temperature": 0.2},
                created_time=hour_ago,
            )
            await runs.create_run(older_run)

            newer_run = RunCreate(
                run_id="run-new",
                thread_id="thread-1",
                user_id="bob",
                status="running",
                model_name="model-b",
                error="queued",
                created_time=just_now,
            )
            await runs.create_run(newer_run)

            unrelated_run = RunCreate(run_id="run-other", thread_id="thread-2", status="running")
            await runs.create_run(unrelated_run)

            success_fields = {
                "status": "success",
                "total_input_tokens": 7,
                "total_output_tokens": 3,
                "total_tokens": 10,
                "llm_call_count": 1,
                "lead_agent_tokens": 8,
                "subagent_tokens": 2,
                "first_human_message": "hello",
                "last_ai_message": "world",
            }
            await runs.update_run_completion("run-old", **success_fields)

            failure_fields = {
                "status": "error",
                "total_tokens": 5,
                "middleware_tokens": 5,
                "error": "failed",
            }
            await runs.update_run_completion("run-new", **failure_fields)

            await db.commit()

        # Phase 2: verify everything from a fresh session.
        async with store.session_factory() as db:
            runs = build_run_repository(db)

            stored = await runs.get_run("run-old")
            assert stored is not None
            assert stored.metadata == {"kind": "draft"}
            assert stored.kwargs == {"temperature": 0.2}
            assert stored.first_human_message == "hello"
            assert stored.last_ai_message == "world"

            # The run with the later created_time is expected first.
            thread_runs = await runs.list_runs_by_thread("thread-1")
            thread_run_ids = [item.run_id for item in thread_runs]
            assert thread_run_ids == ["run-new", "run-old"]

            runs_for_alice = await runs.list_runs_by_thread("thread-1", user_id="alice")
            alice_run_ids = [item.run_id for item in runs_for_alice]
            assert alice_run_ids == ["run-old"]

            # "run-old" was the only run created as pending and has since been
            # completed, so no pending run is expected.
            cutoff = datetime.now(UTC).isoformat()
            still_pending = await runs.list_pending(before=cutoff)
            assert [item.run_id for item in still_pending] == []

            totals = await runs.aggregate_tokens_by_thread("thread-1")
            expected_scalars = {
                "total_tokens": 15,
                "total_input_tokens": 7,
                "total_output_tokens": 3,
                "total_runs": 2,
            }
            for field, expected in expected_scalars.items():
                assert totals[field] == expected
            assert totals["by_model"] == {
                "model-a": {"tokens": 10, "runs": 1},
                "model-b": {"tokens": 5, "runs": 1},
            }
            assert totals["by_caller"] == {"lead_agent": 8, "subagent": 2, "middleware": 5}
    finally:
        await store.aclose()


@pytest.mark.anyio
async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
    """Thread metadata can be updated, searched by filters, and deleted."""
    store = await _make_persistence(tmp_path)
    try:
        # Phase 1: create two threads, then update the first one.
        async with store.session_factory() as db:
            threads = build_thread_meta_repository(db)

            alice_thread = ThreadMetaCreate(
                thread_id="thread-1",
                assistant_id="agent-a",
                user_id="alice",
                display_name="Initial",
                status="idle",
                metadata={"topic": "finance", "region": "cn"},
            )
            await threads.create_thread_meta(alice_thread)

            bob_thread = ThreadMetaCreate(
                thread_id="thread-2",
                assistant_id="agent-b",
                user_id="bob",
                status="running",
                metadata={"topic": "legal"},
            )
            await threads.create_thread_meta(bob_thread)

            changes = {
                "display_name": "Updated",
                "status": "running",
                "metadata": {"topic": "finance", "region": "us"},
            }
            await threads.update_thread_meta("thread-1", **changes)
            await db.commit()

        # Phase 2: read back the update, search, then delete the first thread.
        async with store.session_factory() as db:
            threads = build_thread_meta_repository(db)

            stored = await threads.get_thread_meta("thread-1")
            assert stored is not None
            assert stored.display_name == "Updated"
            assert stored.status == "running"
            assert stored.metadata == {"topic": "finance", "region": "us"}

            finance_hits = await threads.search_threads(metadata={"topic": "finance"}, user_id="alice")
            finance_ids = [hit.thread_id for hit in finance_hits]
            assert finance_ids == ["thread-1"]

            agent_b_hits = await threads.search_threads(assistant_id="agent-b")
            agent_b_ids = [hit.thread_id for hit in agent_b_hits]
            assert agent_b_ids == ["thread-2"]

            await threads.delete_thread("thread-1")
            await db.commit()

        # Phase 3: the deleted thread must no longer be retrievable.
        async with store.session_factory() as db:
            threads = build_thread_meta_repository(db)
            gone = await threads.get_thread_meta("thread-1")
            assert gone is None
    finally:
        await store.aclose()


@pytest.mark.anyio
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
    """Feedback can be listed by run/thread, deleted, and rejects a rating of 0."""
    store = await _make_persistence(tmp_path)
    try:
        # Phase 1: store one positive and one negative feedback entry.
        async with store.session_factory() as db:
            feedback = build_feedback_repository(db)

            thumbs_up = FeedbackCreate(
                feedback_id="fb-1",
                run_id="run-1",
                thread_id="thread-1",
                rating=1,
                user_id="alice",
                message_id="msg-1",
                comment="good",
            )
            first = await feedback.create_feedback(thumbs_up)

            thumbs_down = FeedbackCreate(
                feedback_id="fb-2",
                run_id="run-1",
                thread_id="thread-1",
                rating=-1,
                user_id="bob",
            )
            second = await feedback.create_feedback(thumbs_down)

            await db.commit()

        # Phase 2: read, list, delete, and check rating validation.
        async with store.session_factory() as db:
            feedback = build_feedback_repository(db)

            reloaded = await feedback.get_feedback(first.feedback_id)
            assert reloaded == first

            # The entry created second is expected to be listed first.
            for_run = await feedback.list_feedback_by_run("run-1")
            run_feedback_ids = [entry.feedback_id for entry in for_run]
            assert run_feedback_ids == [second.feedback_id, first.feedback_id]

            for_thread = await feedback.list_feedback_by_thread("thread-1")
            thread_feedback_ids = {entry.feedback_id for entry in for_thread}
            assert thread_feedback_ids == {"fb-1", "fb-2"}

            removed_existing = await feedback.delete_feedback("fb-1")
            assert removed_existing is True
            removed_missing = await feedback.delete_feedback("missing")
            assert removed_missing is False

            # The invalid payload is built inside the ``raises`` block so the
            # check holds whether validation happens at construction or on save.
            with pytest.raises(ValueError, match="rating must be"):
                await feedback.create_feedback(
                    FeedbackCreate(
                        feedback_id="fb-bad",
                        run_id="run-1",
                        thread_id="thread-1",
                        rating=0,
                    )
                )

            await db.commit()

        # Phase 3: the deletion must be visible from a new session.
        async with store.session_factory() as db:
            feedback = build_feedback_repository(db)
            gone = await feedback.get_feedback("fb-1")
            assert gone is None
    finally:
        await store.aclose()


@pytest.mark.anyio
async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_path):
    """Run events get sequence numbers, can be paged and filtered, and deleted."""
    store = await _make_persistence(tmp_path)
    try:
        # Phase 1: append a mixed batch across two threads and three runs.
        async with store.session_factory() as db:
            events_repo = build_run_event_repository(db)

            batch = [
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-1",
                    user_id="alice",
                    event_type="message",
                    category="message",
                    content={"role": "user", "content": "hello"},
                    metadata={"source": "input"},
                ),
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-1",
                    event_type="tool",
                    category="debug",
                    content="tool-call",
                ),
                RunEventCreate(
                    thread_id="thread-1",
                    run_id="run-2",
                    event_type="message",
                    category="message",
                    content="second",
                ),
                RunEventCreate(
                    thread_id="thread-2",
                    run_id="run-3",
                    event_type="message",
                    category="message",
                    content="other-thread",
                ),
            ]
            rows = await events_repo.append_batch(batch)
            await db.commit()

            # Sequence numbers are expected to restart at 1 for each thread.
            assigned = [(row.thread_id, row.seq) for row in rows]
            assert assigned == [
                ("thread-1", 1),
                ("thread-1", 2),
                ("thread-1", 3),
                ("thread-2", 1),
            ]

            # Dict content is expected back unchanged, with an extra
            # ``content_is_json`` flag in the metadata.
            head = rows[0]
            assert head.content == {"role": "user", "content": "hello"}
            assert head.metadata == {"source": "input", "content_is_json": True}

        # Phase 2: paging, filtering, and deletion counts.
        async with store.session_factory() as db:
            events_repo = build_run_event_repository(db)

            # seq 2 is the category="debug" event and is expected to be absent.
            first_page = await events_repo.list_messages("thread-1", limit=2)
            assert [item.seq for item in first_page] == [1, 3]

            message_total = await events_repo.count_messages("thread-1")
            assert message_total == 2

            run_messages = await events_repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
            assert [item.seq for item in run_messages] == [1]

            earlier = await events_repo.list_messages("thread-1", before_seq=3, limit=5)
            assert [item.seq for item in earlier] == [1]

            tool_events = await events_repo.list_events("thread-1", "run-1", event_types=["tool"])
            assert [item.content for item in tool_events] == ["tool-call"]

            deleted_for_run = await events_repo.delete_by_run("thread-1", "run-1")
            assert deleted_for_run == 2
            deleted_for_thread = await events_repo.delete_by_thread("thread-2")
            assert deleted_for_thread == 1

            await db.commit()

        # Phase 3: only the run-2 event of thread-1 should remain.
        async with store.session_factory() as db:
            events_repo = build_run_event_repository(db)

            leftovers = await events_repo.list_events("thread-1", "run-2")
            assert [item.seq for item in leftovers] == [3]

            other_thread_total = await events_repo.count_messages("thread-2")
            assert other_thread_total == 0
    finally:
        await store.aclose()