refactor(journal): fix flush, token tracking, and consolidate tests

RunJournal fixes:
- _flush_sync: retain events in buffer when no event loop instead of
  dropping them; worker's finally block flushes via async flush().
- on_llm_end: add tool_calls filter and caller=="lead_agent" guard for
  ai_message events; mark message IDs for dedup with record_llm_usage.
- worker.py: persist completion data (tokens, message count) to RunStore
  in finally block.

Model factory:
- Auto-inject stream_usage=True for BaseChatOpenAI subclasses with
  custom api_base, so usage_metadata is populated in streaming responses.

Test consolidation:
- Delete test_phase2b_integration.py (redundant with existing tests).
- Move DB-backed lifecycle test into test_run_journal.py.
- Add tests for stream_usage injection in test_model_factory.py.
- Clean up executor/task_tool dead journal references.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-03 17:26:11 +08:00
parent e5b01d7e74
commit b92ddafd4b
7 changed files with 360 additions and 451 deletions
+209 -106
View File
@@ -16,22 +16,28 @@ from deerflow.runtime.journal import RunJournal
@pytest.fixture
def journal_setup():
store = MemoryRunEventStore()
on_complete_data = {}
def on_complete(**data):
on_complete_data.update(data)
j = RunJournal("r1", "t1", store, on_complete=on_complete, flush_threshold=100)
return j, store, on_complete_data
j = RunJournal("r1", "t1", store, flush_threshold=100)
return j, store
def _make_llm_response(content="Hello", usage=None):
def _make_llm_response(content="Hello", usage=None, tool_calls=None):
"""Create a mock LLM response with a message."""
msg = MagicMock()
msg.content = content
msg.tool_calls = []
msg.id = f"msg-{id(msg)}"
msg.tool_calls = tool_calls or []
msg.response_metadata = {"model_name": "test-model"}
msg.usage_metadata = usage
# Provide a real model_dump so serialize_lc_object returns a plain dict
# (needed for DB-backed tests where json.dumps must succeed).
msg.model_dump.return_value = {
"type": "ai",
"content": content,
"id": msg.id,
"tool_calls": tool_calls or [],
"usage_metadata": usage,
"response_metadata": {"model_name": "test-model"},
}
gen = MagicMock()
gen.message = msg
@@ -44,7 +50,7 @@ def _make_llm_response(content="Hello", usage=None):
class TestLlmCallbacks:
@pytest.mark.anyio
async def test_on_llm_end_produces_trace_event(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"])
@@ -56,7 +62,7 @@ class TestLlmCallbacks:
@pytest.mark.anyio
async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, tags=["lead_agent"])
@@ -66,9 +72,23 @@ class TestLlmCallbacks:
assert messages[0]["event_type"] == "ai_message"
assert messages[0]["content"] == "Answer"
@pytest.mark.anyio
async def test_on_llm_end_with_tool_calls_no_ai_message(self, journal_setup):
"""LLM response with pending tool_calls should NOT produce ai_message."""
j, store = journal_setup
run_id = uuid4()
j.on_llm_end(
_make_llm_response("Let me search", tool_calls=[{"name": "search"}]),
run_id=run_id,
tags=["lead_agent"],
)
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
@pytest.mark.anyio
async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"])
j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, tags=["subagent:research"])
@@ -78,27 +98,34 @@ class TestLlmCallbacks:
@pytest.mark.anyio
async def test_token_accumulation(self, journal_setup):
j, store, on_complete_data = journal_setup
j, store = journal_setup
usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), tags=["lead_agent"])
assert j._total_input_tokens == 30
assert j._total_output_tokens == 15
assert j._total_tokens == 45
assert j._llm_call_count == 2
@pytest.mark.anyio
async def test_total_tokens_computed_from_input_output(self, journal_setup):
"""If total_tokens is 0, it should be computed from input + output."""
j, store = journal_setup
j.on_llm_end(
_make_llm_response("Hi", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}),
run_id=uuid4(),
tags=["lead_agent"],
)
assert j._total_tokens == 150
assert j._lead_agent_tokens == 150
@pytest.mark.anyio
async def test_caller_token_classification(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["subagent:research"])
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), tags=["subagent:research"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["middleware:summarization"])
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), tags=["middleware:summarization"])
assert j._lead_agent_tokens == 15
assert j._subagent_tokens == 15
@@ -106,15 +133,13 @@ class TestLlmCallbacks:
@pytest.mark.anyio
async def test_usage_metadata_none_no_crash(self, journal_setup):
j, store, _ = journal_setup
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j, store = journal_setup
j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"])
# Should not raise
await j.flush()
@pytest.mark.anyio
async def test_latency_tracking(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"])
@@ -127,16 +152,20 @@ class TestLlmCallbacks:
class TestLifecycleCallbacks:
@pytest.mark.anyio
async def test_on_chain_end_triggers_on_complete(self, journal_setup):
j, store, on_complete_data = journal_setup
async def test_chain_start_end_produce_lifecycle_events(self, journal_setup):
j, store = journal_setup
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
assert "total_tokens" in on_complete_data
assert "message_count" in on_complete_data
await asyncio.sleep(0.05)
await j.flush()
events = await store.list_events("t1", "r1")
types = [e["event_type"] for e in events if e["category"] == "lifecycle"]
assert "run_start" in types
assert "run_end" in types
@pytest.mark.anyio
async def test_nested_chain_ignored(self, journal_setup):
j, store, on_complete_data = journal_setup
j, store = journal_setup
parent_id = uuid4()
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id)
@@ -149,7 +178,7 @@ class TestLifecycleCallbacks:
class TestToolCallbacks:
@pytest.mark.anyio
async def test_tool_start_end_produce_trace(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
j.on_tool_start({"name": "web_search"}, "query", run_id=uuid4())
j.on_tool_end("results", run_id=uuid4(), name="web_search")
await j.flush()
@@ -158,11 +187,19 @@ class TestToolCallbacks:
assert "tool_start" in types
assert "tool_end" in types
@pytest.mark.anyio
async def test_on_tool_error(self, journal_setup):
j, store = journal_setup
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "tool_error" for e in events)
class TestCustomEvents:
@pytest.mark.anyio
async def test_summarization_event(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
j.on_custom_event(
"summarization",
{"summary": "Context was summarized.", "replaced_count": 5, "replaced_message_ids": ["a", "b"]},
@@ -176,50 +213,76 @@ class TestCustomEvents:
assert len(messages) == 1
assert messages[0]["event_type"] == "summary"
@pytest.mark.anyio
async def test_non_summarization_custom_event(self, journal_setup):
j, store = journal_setup
j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4())
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "task_running" for e in events)
class TestBufferFlush:
@pytest.mark.anyio
async def test_flush_threshold(self, journal_setup):
j, store, _ = journal_setup
j, store = journal_setup
j._flush_threshold = 3
j.on_tool_start({"name": "a"}, "x", run_id=uuid4())
j.on_tool_start({"name": "b"}, "x", run_id=uuid4())
# Buffer has 2 events, not yet flushed
assert len(j._buffer) == 2
j.on_tool_start({"name": "c"}, "x", run_id=uuid4())
# Buffer should have been flushed (threshold=3 triggers flush)
# Give the async task a chance to complete
await asyncio.sleep(0.1)
events = await store.list_events("t1", "r1")
assert len(events) >= 3
@pytest.mark.anyio
async def test_events_retained_when_no_loop(self, journal_setup):
"""Events buffered in a sync (no-loop) context should survive
until the async flush() in the finally block."""
j, store = journal_setup
j._flush_threshold = 1
original = asyncio.get_running_loop
def no_loop():
raise RuntimeError("no running event loop")
asyncio.get_running_loop = no_loop
try:
j._put(event_type="llm_end", category="trace", content="test")
finally:
asyncio.get_running_loop = original
assert len(j._buffer) == 1
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "llm_end" for e in events)
class TestIdentifyCaller:
def test_lead_agent_tag(self, journal_setup):
j, _, _ = journal_setup
j, _ = journal_setup
assert j._identify_caller({"tags": ["lead_agent"]}) == "lead_agent"
def test_subagent_tag(self, journal_setup):
j, _, _ = journal_setup
j, _ = journal_setup
assert j._identify_caller({"tags": ["subagent:research"]}) == "subagent:research"
def test_middleware_tag(self, journal_setup):
j, _, _ = journal_setup
j, _ = journal_setup
assert j._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization"
def test_no_tags_returns_unknown(self, journal_setup):
j, _, _ = journal_setup
assert j._identify_caller({"tags": []}) == "unknown"
assert j._identify_caller({}) == "unknown"
def test_no_tags_returns_lead_agent(self, journal_setup):
j, _ = journal_setup
assert j._identify_caller({"tags": []}) == "lead_agent"
assert j._identify_caller({}) == "lead_agent"
class TestChainErrorCallback:
@pytest.mark.anyio
async def test_on_chain_error_writes_run_error(self, journal_setup):
j, store, _ = journal_setup
# parent_run_id must be None (top-level chain) for the event to be recorded
j, store = journal_setup
j.on_chain_error(ValueError("boom"), run_id=uuid4(), parent_run_id=None)
# on_chain_error calls _flush_sync internally, give async task time to complete
await asyncio.sleep(0.05)
await j.flush()
events = await store.list_events("t1", "r1")
@@ -232,85 +295,125 @@ class TestChainErrorCallback:
class TestTokenTrackingDisabled:
@pytest.mark.anyio
async def test_track_token_usage_false(self):
"""track_token_usage=False disables token accumulation."""
store = MemoryRunEventStore()
complete_data = {}
j = RunJournal("r1", "t1", store, track_token_usage=False, on_complete=lambda **d: complete_data.update(d), flush_threshold=100)
j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}), run_id=uuid4(), tags=["lead_agent"])
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
assert complete_data["total_tokens"] == 0
assert complete_data["llm_call_count"] == 0
class TestMiddlewareNoMessage:
@pytest.mark.anyio
async def test_on_llm_end_middleware_no_ai_message(self, journal_setup):
j, store, _ = journal_setup
j.on_llm_end(_make_llm_response("Summary"), run_id=uuid4(), tags=["middleware:summarization"])
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
class TestUnknownCallerTokens:
@pytest.mark.anyio
async def test_unknown_caller_tokens_go_to_lead(self, journal_setup):
"""No caller tag: tokens attributed to lead_agent bucket."""
j, store, _ = journal_setup
j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=uuid4(), tags=[])
assert j._lead_agent_tokens == 15
j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
j.on_llm_end(
_make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}),
run_id=uuid4(),
tags=["lead_agent"],
)
data = j.get_completion_data()
assert data["total_tokens"] == 0
assert data["llm_call_count"] == 0
class TestConvenienceFields:
@pytest.mark.anyio
async def test_last_ai_message_tracks_latest(self, journal_setup):
j, store, complete_data = journal_setup
j, store = journal_setup
j.on_llm_end(_make_llm_response("First"), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Second"), run_id=uuid4(), tags=["lead_agent"])
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
assert complete_data["last_ai_message"] == "Second"
assert complete_data["message_count"] == 2
data = j.get_completion_data()
assert data["last_ai_message"] == "Second"
assert data["message_count"] == 2
@pytest.mark.anyio
async def test_first_human_message_via_set(self, journal_setup):
j, store, complete_data = journal_setup
j, _ = journal_setup
j.set_first_human_message("What is AI?")
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
assert complete_data["first_human_message"] == "What is AI?"
class TestToolError:
@pytest.mark.anyio
async def test_on_tool_error(self, journal_setup):
j, store, _ = journal_setup
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "tool_error" for e in events)
class TestOtherCustomEvent:
@pytest.mark.anyio
async def test_non_summarization_custom_event(self, journal_setup):
j, store, _ = journal_setup
j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4())
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "task_running" for e in events)
class TestPublicMethods:
@pytest.mark.anyio
async def test_set_first_human_message(self, journal_setup):
j, _, _ = journal_setup
j.set_first_human_message("Hello world")
assert j._first_human_msg == "Hello world"
data = j.get_completion_data()
assert data["first_human_message"] == "What is AI?"
@pytest.mark.anyio
async def test_get_completion_data(self, journal_setup):
j, _, _ = journal_setup
j, _ = journal_setup
j._total_tokens = 100
j._msg_count = 5
data = j.get_completion_data()
assert data["total_tokens"] == 100
assert data["message_count"] == 5
class TestUnknownCallerTokens:
@pytest.mark.anyio
async def test_unknown_caller_tokens_go_to_lead(self, journal_setup):
j, store = journal_setup
j.on_llm_end(
_make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
tags=[],
)
assert j._lead_agent_tokens == 15
# ---------------------------------------------------------------------------
# SQLite-backed end-to-end test
# ---------------------------------------------------------------------------
class TestDbBackedLifecycle:
@pytest.mark.anyio
async def test_full_lifecycle_with_sqlite(self, tmp_path):
"""Full lifecycle with SQLite-backed RunRepository + DbRunEventStore."""
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.persistence.repositories.run_repo import RunRepository
from deerflow.runtime.events.store.db import DbRunEventStore
from deerflow.runtime.runs.manager import RunManager
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
sf = get_session_factory()
run_store = RunRepository(sf)
event_store = DbRunEventStore(sf)
mgr = RunManager(store=run_store)
# Create run
record = await mgr.create("t1", "lead_agent")
run_id = record.run_id
# Write human_message
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content="Hello DB")
# Simulate journal
journal = RunJournal(run_id, "t1", event_store, flush_threshold=100)
journal.set_first_human_message("Hello DB")
journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
llm_rid = uuid4()
journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"])
journal.on_llm_end(
_make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=llm_rid,
tags=["lead_agent"],
)
journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
await asyncio.sleep(0.05)
await journal.flush()
# Verify run persisted
row = await run_store.get(run_id)
assert row is not None
assert row["status"] == "pending"
# Update completion
completion = journal.get_completion_data()
await run_store.update_run_completion(run_id, status="success", **completion)
row = await run_store.get(run_id)
assert row["status"] == "success"
assert row["total_tokens"] == 15
# Verify messages from DB
messages = await event_store.list_messages("t1")
assert len(messages) == 2
assert messages[0]["event_type"] == "human_message"
assert messages[1]["event_type"] == "ai_message"
# Verify events from DB
events = await event_store.list_events("t1", run_id)
event_types = {e["event_type"] for e in events}
assert "run_start" in event_types
assert "llm_end" in event_types
assert "run_end" in event_types
await close_engine()