mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 07:56:48 +00:00
feat(persistence): add ORM models, repositories, DB/JSONL event stores, RunJournal, and API endpoints
Phase 2-B: run persistence + event storage + token tracking. - ORM models: RunRow (with token fields), ThreadMetaRow, RunEventRow - RunRepository implements RunStore ABC via SQLAlchemy ORM - ThreadMetaRepository with owner access control - DbRunEventStore with trace content truncation and cursor pagination - JsonlRunEventStore with per-run files and seq recovery from disk - RunJournal (BaseCallbackHandler) captures LLM/tool/lifecycle events, accumulates token usage by caller type, buffers and flushes to store - RunManager now accepts optional RunStore for persistent backing - Worker creates RunJournal, writes human_message, injects callbacks - Gateway deps use factory functions (RunRepository when DB available) - New endpoints: messages, run messages, run events, token-usage - ThreadCreateRequest gains assistant_id field - 92 tests pass (33 new), zero regressions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,154 @@
|
||||
"""Phase 2-B integration tests.
|
||||
|
||||
End-to-end test: simulate a run's complete lifecycle, verify data
|
||||
is correctly written to both RunStore and RunEventStore.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
|
||||
|
||||
def _make_llm_response(content="Hello", usage=None):
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.tool_calls = []
|
||||
msg.response_metadata = {"model_name": "test-model"}
|
||||
msg.usage_metadata = usage
|
||||
|
||||
gen = MagicMock()
|
||||
gen.message = msg
|
||||
|
||||
response = MagicMock()
|
||||
response.generations = [[gen]]
|
||||
return response
|
||||
|
||||
|
||||
class TestRunLifecycle:
|
||||
@pytest.mark.anyio
|
||||
async def test_full_run_lifecycle(self):
|
||||
"""Simulate a complete run lifecycle with RunStore + RunEventStore."""
|
||||
run_store = MemoryRunStore()
|
||||
event_store = MemoryRunEventStore()
|
||||
|
||||
# 1. Create run
|
||||
await run_store.put("r1", thread_id="t1", status="pending")
|
||||
|
||||
# 2. Write human_message
|
||||
await event_store.put(
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content="What is AI?",
|
||||
)
|
||||
|
||||
# 3. Simulate RunJournal callback sequence
|
||||
on_complete_data = {}
|
||||
|
||||
def on_complete(**data):
|
||||
on_complete_data.update(data)
|
||||
|
||||
journal = RunJournal("r1", "t1", event_store, on_complete=on_complete, flush_threshold=100)
|
||||
journal.set_first_human_message("What is AI?")
|
||||
|
||||
# chain_start (top-level)
|
||||
journal.on_chain_start({}, {"messages": ["What is AI?"]}, run_id=uuid4(), parent_run_id=None)
|
||||
|
||||
# llm_start + llm_end
|
||||
llm_run_id = uuid4()
|
||||
journal.on_llm_start({"name": "gpt-4"}, ["prompt"], run_id=llm_run_id, tags=["lead_agent"])
|
||||
usage = {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150}
|
||||
journal.on_llm_end(_make_llm_response("AI is artificial intelligence.", usage=usage), run_id=llm_run_id, tags=["lead_agent"])
|
||||
|
||||
# chain_end (triggers on_complete + flush_sync which creates a task)
|
||||
journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
|
||||
await journal.flush()
|
||||
# Let event loop process any pending flush tasks from _flush_sync
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
# 4. Verify messages
|
||||
messages = await event_store.list_messages("t1")
|
||||
assert len(messages) == 2 # human + ai
|
||||
assert messages[0]["event_type"] == "human_message"
|
||||
assert messages[1]["event_type"] == "ai_message"
|
||||
assert messages[1]["content"] == "AI is artificial intelligence."
|
||||
|
||||
# 5. Verify events
|
||||
events = await event_store.list_events("t1", "r1")
|
||||
event_types = {e["event_type"] for e in events}
|
||||
assert "run_start" in event_types
|
||||
assert "llm_start" in event_types
|
||||
assert "llm_end" in event_types
|
||||
assert "run_end" in event_types
|
||||
|
||||
# 6. Verify on_complete data
|
||||
assert on_complete_data["total_tokens"] == 150
|
||||
assert on_complete_data["llm_call_count"] == 1
|
||||
assert on_complete_data["lead_agent_tokens"] == 150
|
||||
assert on_complete_data["message_count"] == 1
|
||||
assert on_complete_data["last_ai_message"] == "AI is artificial intelligence."
|
||||
assert on_complete_data["first_human_message"] == "What is AI?"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_run_with_tool_calls(self):
|
||||
"""Simulate a run that uses tools."""
|
||||
event_store = MemoryRunEventStore()
|
||||
journal = RunJournal("r1", "t1", event_store, flush_threshold=100)
|
||||
|
||||
# tool_start + tool_end
|
||||
journal.on_tool_start({"name": "web_search"}, '{"query": "AI"}', run_id=uuid4())
|
||||
journal.on_tool_end("Search results...", run_id=uuid4(), name="web_search")
|
||||
await journal.flush()
|
||||
|
||||
events = await event_store.list_events("t1", "r1")
|
||||
assert len(events) == 2
|
||||
assert events[0]["event_type"] == "tool_start"
|
||||
assert events[1]["event_type"] == "tool_end"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_multi_run_thread(self):
|
||||
"""Multiple runs on the same thread maintain unified seq ordering."""
|
||||
event_store = MemoryRunEventStore()
|
||||
|
||||
# Run 1
|
||||
await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
|
||||
await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")
|
||||
|
||||
# Run 2
|
||||
await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2")
|
||||
await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2")
|
||||
|
||||
messages = await event_store.list_messages("t1")
|
||||
assert len(messages) == 4
|
||||
assert [m["seq"] for m in messages] == [1, 2, 3, 4]
|
||||
assert messages[0]["run_id"] == "r1"
|
||||
assert messages[2]["run_id"] == "r2"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_runmanager_with_store_backing(self):
|
||||
"""RunManager persists to RunStore when one is provided."""
|
||||
from deerflow.runtime.runs.manager import RunManager
|
||||
|
||||
run_store = MemoryRunStore()
|
||||
mgr = RunManager(store=run_store)
|
||||
|
||||
record = await mgr.create("t1", assistant_id="lead_agent")
|
||||
# Verify persisted to store
|
||||
row = await run_store.get(record.run_id)
|
||||
assert row is not None
|
||||
assert row["thread_id"] == "t1"
|
||||
assert row["status"] == "pending"
|
||||
|
||||
# Status update
|
||||
from deerflow.runtime.runs.schemas import RunStatus
|
||||
|
||||
await mgr.set_status(record.run_id, RunStatus.running)
|
||||
row = await run_store.get(record.run_id)
|
||||
assert row["status"] == "running"
|
||||
@@ -1,14 +1,7 @@
|
||||
"""Tests for RunEventStore ABC + MemoryRunEventStore.
|
||||
"""Tests for RunEventStore contract across all backends.
|
||||
|
||||
Covers:
|
||||
- Basic write and query (put, seq assignment, cross-thread independence)
|
||||
- list_messages (category filtering, pagination, cross-run ordering)
|
||||
- list_events (run filtering, event_types filtering)
|
||||
- list_messages_by_run
|
||||
- count_messages
|
||||
- put_batch
|
||||
- delete_by_thread, delete_by_run
|
||||
- Edge cases (empty thread/run)
|
||||
Uses a helper to create the store for each backend type.
|
||||
Memory tests run directly; DB and JSONL tests create stores inside each test.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
@@ -35,7 +28,6 @@ class TestPutAndSeq:
|
||||
assert record["event_type"] == "human_message"
|
||||
assert record["category"] == "message"
|
||||
assert record["content"] == "hello"
|
||||
assert record["metadata"] == {}
|
||||
assert "created_at" in record
|
||||
|
||||
@pytest.mark.anyio
|
||||
@@ -91,7 +83,6 @@ class TestListMessages:
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_before_seq_pagination(self, store):
|
||||
# Put 10 messages with seq 1..10
|
||||
for i in range(10):
|
||||
await store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
|
||||
messages = await store.list_messages("t1", before_seq=6, limit=3)
|
||||
@@ -236,7 +227,6 @@ class TestDelete:
|
||||
await store.put(thread_id="t1", run_id="r2", event_type="llm_end", category="trace")
|
||||
count = await store.delete_by_run("t1", "r2")
|
||||
assert count == 2
|
||||
# r1 events should still be there
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["run_id"] == "r1"
|
||||
@@ -270,3 +260,145 @@ class TestEdgeCases:
|
||||
@pytest.mark.anyio
|
||||
async def test_empty_thread_count_messages(self, store):
|
||||
assert await store.count_messages("empty") == 0
|
||||
|
||||
|
||||
# -- DB-specific tests --
|
||||
|
||||
|
||||
class TestDbRunEventStore:
|
||||
"""Tests for DbRunEventStore with temp SQLite."""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_basic_crud(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
|
||||
assert r["seq"] == 1
|
||||
r2 = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="hello")
|
||||
assert r2["seq"] == 2
|
||||
|
||||
messages = await s.list_messages("t1")
|
||||
assert len(messages) == 2
|
||||
|
||||
count = await s.count_messages("t1")
|
||||
assert count == 2
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_trace_content_truncation(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory(), max_trace_content=100)
|
||||
|
||||
long = "x" * 200
|
||||
r = await s.put(thread_id="t1", run_id="r1", event_type="llm_end", category="trace", content=long)
|
||||
assert len(r["content"]) == 100
|
||||
assert r["metadata"].get("content_truncated") is True
|
||||
|
||||
# message content NOT truncated
|
||||
m = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=long)
|
||||
assert len(m["content"]) == 200
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_pagination(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
for i in range(10):
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content=str(i))
|
||||
|
||||
# before_seq
|
||||
msgs = await s.list_messages("t1", before_seq=6, limit=3)
|
||||
assert [m["seq"] for m in msgs] == [3, 4, 5]
|
||||
|
||||
# after_seq
|
||||
msgs = await s.list_messages("t1", after_seq=7, limit=3)
|
||||
assert [m["seq"] for m in msgs] == [8, 9, 10]
|
||||
|
||||
# default (latest)
|
||||
msgs = await s.list_messages("t1", limit=3)
|
||||
assert [m["seq"] for m in msgs] == [8, 9, 10]
|
||||
|
||||
await close_engine()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete(self, tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
s = DbRunEventStore(get_session_factory())
|
||||
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await s.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message")
|
||||
c = await s.delete_by_run("t1", "r2")
|
||||
assert c == 1
|
||||
assert await s.count_messages("t1") == 1
|
||||
|
||||
c = await s.delete_by_thread("t1")
|
||||
assert c == 1
|
||||
assert await s.count_messages("t1") == 0
|
||||
|
||||
await close_engine()
|
||||
|
||||
|
||||
# -- JSONL-specific tests --
|
||||
|
||||
|
||||
class TestJsonlRunEventStore:
|
||||
@pytest.mark.anyio
|
||||
async def test_basic_crud(self, tmp_path):
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
|
||||
r = await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="hi")
|
||||
assert r["seq"] == 1
|
||||
messages = await s.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_file_at_correct_path(self, tmp_path):
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
assert (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r1.jsonl").exists()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_cross_run_messages(self, tmp_path):
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
|
||||
messages = await s.list_messages("t1")
|
||||
assert len(messages) == 2
|
||||
assert [m["seq"] for m in messages] == [1, 2]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_by_run(self, tmp_path):
|
||||
from deerflow.runtime.events.store.jsonl import JsonlRunEventStore
|
||||
|
||||
s = JsonlRunEventStore(base_dir=tmp_path / "jsonl")
|
||||
await s.put(thread_id="t1", run_id="r1", event_type="human_message", category="message")
|
||||
await s.put(thread_id="t1", run_id="r2", event_type="human_message", category="message")
|
||||
c = await s.delete_by_run("t1", "r2")
|
||||
assert c == 1
|
||||
assert not (tmp_path / "jsonl" / "threads" / "t1" / "runs" / "r2.jsonl").exists()
|
||||
assert await s.count_messages("t1") == 1
|
||||
|
||||
@@ -0,0 +1,230 @@
|
||||
"""Tests for RunJournal callback handler.
|
||||
|
||||
Uses MemoryRunEventStore as the backend for direct event inspection.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.runtime.events.store.memory import MemoryRunEventStore
|
||||
from deerflow.runtime.journal import RunJournal
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def journal_setup():
|
||||
store = MemoryRunEventStore()
|
||||
on_complete_data = {}
|
||||
|
||||
def on_complete(**data):
|
||||
on_complete_data.update(data)
|
||||
|
||||
j = RunJournal("r1", "t1", store, on_complete=on_complete, flush_threshold=100)
|
||||
return j, store, on_complete_data
|
||||
|
||||
|
||||
def _make_llm_response(content="Hello", usage=None):
|
||||
"""Create a mock LLM response with a message."""
|
||||
msg = MagicMock()
|
||||
msg.content = content
|
||||
msg.tool_calls = []
|
||||
msg.response_metadata = {"model_name": "test-model"}
|
||||
msg.usage_metadata = usage
|
||||
|
||||
gen = MagicMock()
|
||||
gen.message = msg
|
||||
|
||||
response = MagicMock()
|
||||
response.generations = [[gen]]
|
||||
return response
|
||||
|
||||
|
||||
class TestLlmCallbacks:
|
||||
@pytest.mark.anyio
|
||||
async def test_on_llm_end_produces_trace_event(self, journal_setup):
|
||||
j, store, _ = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"])
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
trace_events = [e for e in events if e["event_type"] == "llm_end"]
|
||||
assert len(trace_events) == 1
|
||||
assert trace_events[0]["category"] == "trace"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
|
||||
j, store, _ = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, tags=["lead_agent"])
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "ai_message"
|
||||
assert messages[0]["content"] == "Answer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
|
||||
j, store, _ = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"])
|
||||
j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, tags=["subagent:research"])
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_token_accumulation(self, journal_setup):
|
||||
j, store, on_complete_data = journal_setup
|
||||
usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
|
||||
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), tags=["lead_agent"])
|
||||
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), tags=["lead_agent"])
|
||||
assert j._total_input_tokens == 30
|
||||
assert j._total_output_tokens == 15
|
||||
assert j._total_tokens == 45
|
||||
assert j._llm_call_count == 2
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_caller_token_classification(self, journal_setup):
|
||||
j, store, _ = journal_setup
|
||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), tags=["lead_agent"])
|
||||
j.on_llm_start({}, [], run_id=uuid4(), tags=["subagent:research"])
|
||||
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), tags=["subagent:research"])
|
||||
j.on_llm_start({}, [], run_id=uuid4(), tags=["middleware:summarization"])
|
||||
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), tags=["middleware:summarization"])
|
||||
assert j._lead_agent_tokens == 15
|
||||
assert j._subagent_tokens == 15
|
||||
assert j._middleware_tokens == 15
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_usage_metadata_none_no_crash(self, journal_setup):
|
||||
j, store, _ = journal_setup
|
||||
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"])
|
||||
# Should not raise
|
||||
await j.flush()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_latency_tracking(self, journal_setup):
|
||||
j, store, _ = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"])
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
llm_end = [e for e in events if e["event_type"] == "llm_end"][0]
|
||||
assert "latency_ms" in llm_end["metadata"]
|
||||
assert llm_end["metadata"]["latency_ms"] is not None
|
||||
|
||||
|
||||
class TestLifecycleCallbacks:
|
||||
@pytest.mark.anyio
|
||||
async def test_on_chain_end_triggers_on_complete(self, journal_setup):
|
||||
j, store, on_complete_data = journal_setup
|
||||
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
|
||||
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
|
||||
assert "total_tokens" in on_complete_data
|
||||
assert "message_count" in on_complete_data
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_nested_chain_ignored(self, journal_setup):
|
||||
j, store, on_complete_data = journal_setup
|
||||
parent_id = uuid4()
|
||||
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
|
||||
j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id)
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
lifecycle = [e for e in events if e["category"] == "lifecycle"]
|
||||
assert len(lifecycle) == 0
|
||||
|
||||
|
||||
class TestToolCallbacks:
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_start_end_produce_trace(self, journal_setup):
|
||||
j, store, _ = journal_setup
|
||||
j.on_tool_start({"name": "web_search"}, "query", run_id=uuid4())
|
||||
j.on_tool_end("results", run_id=uuid4(), name="web_search")
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
types = {e["event_type"] for e in events}
|
||||
assert "tool_start" in types
|
||||
assert "tool_end" in types
|
||||
|
||||
|
||||
class TestCustomEvents:
|
||||
@pytest.mark.anyio
|
||||
async def test_summarization_event(self, journal_setup):
|
||||
j, store, _ = journal_setup
|
||||
j.on_custom_event(
|
||||
"summarization",
|
||||
{"summary": "Context was summarized.", "replaced_count": 5, "replaced_message_ids": ["a", "b"]},
|
||||
run_id=uuid4(),
|
||||
)
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
trace = [e for e in events if e["event_type"] == "summarization"]
|
||||
assert len(trace) == 1
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "summary"
|
||||
|
||||
|
||||
class TestBufferFlush:
|
||||
@pytest.mark.anyio
|
||||
async def test_flush_threshold(self, journal_setup):
|
||||
j, store, _ = journal_setup
|
||||
j._flush_threshold = 3
|
||||
j.on_tool_start({"name": "a"}, "x", run_id=uuid4())
|
||||
j.on_tool_start({"name": "b"}, "x", run_id=uuid4())
|
||||
# Buffer has 2 events, not yet flushed
|
||||
assert len(j._buffer) == 2
|
||||
j.on_tool_start({"name": "c"}, "x", run_id=uuid4())
|
||||
# Buffer should have been flushed (threshold=3 triggers flush)
|
||||
# Give the async task a chance to complete
|
||||
await asyncio.sleep(0.1)
|
||||
events = await store.list_events("t1", "r1")
|
||||
assert len(events) >= 3
|
||||
|
||||
|
||||
class TestIdentifyCaller:
|
||||
def test_lead_agent_tag(self, journal_setup):
|
||||
j, _, _ = journal_setup
|
||||
assert j._identify_caller({"tags": ["lead_agent"]}) == "lead_agent"
|
||||
|
||||
def test_subagent_tag(self, journal_setup):
|
||||
j, _, _ = journal_setup
|
||||
assert j._identify_caller({"tags": ["subagent:research"]}) == "subagent:research"
|
||||
|
||||
def test_middleware_tag(self, journal_setup):
|
||||
j, _, _ = journal_setup
|
||||
assert j._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization"
|
||||
|
||||
def test_no_tags_returns_unknown(self, journal_setup):
|
||||
j, _, _ = journal_setup
|
||||
assert j._identify_caller({"tags": []}) == "unknown"
|
||||
assert j._identify_caller({}) == "unknown"
|
||||
|
||||
|
||||
class TestPublicMethods:
|
||||
@pytest.mark.anyio
|
||||
async def test_set_first_human_message(self, journal_setup):
|
||||
j, _, _ = journal_setup
|
||||
j.set_first_human_message("Hello world")
|
||||
assert j._first_human_msg == "Hello world"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_completion_data(self, journal_setup):
|
||||
j, _, _ = journal_setup
|
||||
j._total_tokens = 100
|
||||
j._msg_count = 5
|
||||
data = j.get_completion_data()
|
||||
assert data["total_tokens"] == 100
|
||||
assert data["message_count"] == 5
|
||||
@@ -0,0 +1,155 @@
|
||||
"""Tests for RunRepository (SQLAlchemy-backed RunStore).
|
||||
|
||||
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.persistence.repositories.run_repo import RunRepository
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
return RunRepository(get_session_factory())
|
||||
|
||||
|
||||
async def _cleanup():
|
||||
from deerflow.persistence.engine import close_engine
|
||||
|
||||
await close_engine()
|
||||
|
||||
|
||||
class TestRunRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_put_and_get(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="pending")
|
||||
row = await repo.get("r1")
|
||||
assert row is not None
|
||||
assert row["run_id"] == "r1"
|
||||
assert row["thread_id"] == "t1"
|
||||
assert row["status"] == "pending"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_missing_returns_none(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
assert await repo.get("nope") is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_status(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1")
|
||||
await repo.update_status("r1", "running")
|
||||
row = await repo.get("r1")
|
||||
assert row["status"] == "running"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_status_with_error(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1")
|
||||
await repo.update_status("r1", "error", error="boom")
|
||||
row = await repo.get("r1")
|
||||
assert row["status"] == "error"
|
||||
assert row["error"] == "boom"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1")
|
||||
await repo.put("r2", thread_id="t1")
|
||||
await repo.put("r3", thread_id="t2")
|
||||
rows = await repo.list_by_thread("t1")
|
||||
assert len(rows) == 2
|
||||
assert all(r["thread_id"] == "t1" for r in rows)
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_by_thread_owner_filter(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", owner_id="alice")
|
||||
await repo.put("r2", thread_id="t1", owner_id="bob")
|
||||
rows = await repo.list_by_thread("t1", owner_id="alice")
|
||||
assert len(rows) == 1
|
||||
assert rows[0]["owner_id"] == "alice"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1")
|
||||
await repo.delete("r1")
|
||||
assert await repo.get("r1") is None
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_delete_nonexistent_is_noop(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.delete("nope") # should not raise
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_list_pending(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="pending")
|
||||
await repo.put("r2", thread_id="t1", status="running")
|
||||
await repo.put("r3", thread_id="t2", status="pending")
|
||||
pending = await repo.list_pending()
|
||||
assert len(pending) == 2
|
||||
assert all(r["status"] == "pending" for r in pending)
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_run_completion(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", status="running")
|
||||
await repo.update_run_completion(
|
||||
"r1",
|
||||
status="success",
|
||||
total_input_tokens=100,
|
||||
total_output_tokens=50,
|
||||
total_tokens=150,
|
||||
llm_call_count=2,
|
||||
lead_agent_tokens=120,
|
||||
subagent_tokens=20,
|
||||
middleware_tokens=10,
|
||||
message_count=3,
|
||||
last_ai_message="The answer is 42",
|
||||
first_human_message="What is the meaning?",
|
||||
)
|
||||
row = await repo.get("r1")
|
||||
assert row["status"] == "success"
|
||||
assert row["total_tokens"] == 150
|
||||
assert row["llm_call_count"] == 2
|
||||
assert row["lead_agent_tokens"] == 120
|
||||
assert row["message_count"] == 3
|
||||
assert row["last_ai_message"] == "The answer is 42"
|
||||
assert row["first_human_message"] == "What is the meaning?"
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_metadata_preserved(self, tmp_path):
|
||||
repo = await _make_repo(tmp_path)
|
||||
await repo.put("r1", thread_id="t1", metadata={"key": "value"})
|
||||
row = await repo.get("r1")
|
||||
assert row["metadata"] == {"key": "value"}
|
||||
await _cleanup()
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_kwargs_with_non_serializable(self, tmp_path):
|
||||
"""kwargs containing non-JSON-serializable objects should be safely handled."""
|
||||
repo = await _make_repo(tmp_path)
|
||||
|
||||
class Dummy:
|
||||
pass
|
||||
|
||||
await repo.put("r1", thread_id="t1", kwargs={"obj": Dummy()})
|
||||
row = await repo.get("r1")
|
||||
assert "obj" in row["kwargs"]
|
||||
await _cleanup()
|
||||
Reference in New Issue
Block a user