refactor(journal): fix flush, token tracking, and consolidate tests

RunJournal fixes:
- _flush_sync: retain events in buffer when no event loop instead of
  dropping them; worker's finally block flushes via async flush().
- on_llm_end: add tool_calls filter and caller=="lead_agent" guard for
  ai_message events; mark message IDs for dedup with record_llm_usage.
- worker.py: persist completion data (tokens, message count) to RunStore
  in finally block.

Model factory:
- Auto-inject stream_usage=True for BaseChatOpenAI subclasses with
  custom api_base, so usage_metadata is populated in streaming responses.

Test consolidation:
- Delete test_phase2b_integration.py (redundant with existing tests).
- Move DB-backed lifecycle test into test_run_journal.py.
- Add tests for stream_usage injection in test_model_factory.py.
- Clean up executor/task_tool dead journal references.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-03 17:26:11 +08:00
parent e5b01d7e74
commit b92ddafd4b
7 changed files with 360 additions and 451 deletions
@@ -77,6 +77,15 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
elif "reasoning_effort" not in model_settings_from_config: elif "reasoning_effort" not in model_settings_from_config:
model_settings_from_config["reasoning_effort"] = "medium" model_settings_from_config["reasoning_effort"] = "medium"
# Ensure stream_usage is enabled so that token usage metadata is available
# in streaming responses. LangChain's BaseChatOpenAI only defaults
# stream_usage=True when no custom base_url/api_base is set, so models
# hitting third-party endpoints (e.g. doubao, deepseek) silently lose
# usage data. We default it to True unless explicitly configured.
if "stream_usage" not in model_settings_from_config and "stream_usage" not in kwargs:
if "stream_usage" in getattr(model_class, "model_fields", {}):
model_settings_from_config["stream_usage"] = True
model_instance = model_class(**kwargs, **model_settings_from_config) model_instance = model_class(**kwargs, **model_settings_from_config)
if is_tracing_enabled(): if is_tracing_enabled():
@@ -16,7 +16,6 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time import time
from collections.abc import Callable
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from uuid import UUID from uuid import UUID
@@ -39,7 +38,6 @@ class RunJournal(BaseCallbackHandler):
event_store: RunEventStore, event_store: RunEventStore,
*, *,
track_token_usage: bool = True, track_token_usage: bool = True,
on_complete: Callable[..., Any] | None = None,
flush_threshold: int = 20, flush_threshold: int = 20,
): ):
super().__init__() super().__init__()
@@ -47,7 +45,6 @@ class RunJournal(BaseCallbackHandler):
self.thread_id = thread_id self.thread_id = thread_id
self._store = event_store self._store = event_store
self._track_tokens = track_token_usage self._track_tokens = track_token_usage
self._on_complete = on_complete
self._flush_threshold = flush_threshold self._flush_threshold = flush_threshold
# Write buffer # Write buffer
@@ -73,7 +70,6 @@ class RunJournal(BaseCallbackHandler):
# -- Lifecycle callbacks -- # -- Lifecycle callbacks --
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None: def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
# Only record for the top-level chain (parent_run_id is None)
if kwargs.get("parent_run_id") is not None: if kwargs.get("parent_run_id") is not None:
return return
self._put( self._put(
@@ -87,19 +83,6 @@ class RunJournal(BaseCallbackHandler):
return return
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"}) self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
self._flush_sync() self._flush_sync()
if self._on_complete:
self._on_complete(
total_input_tokens=self._total_input_tokens,
total_output_tokens=self._total_output_tokens,
total_tokens=self._total_tokens,
llm_call_count=self._llm_call_count,
lead_agent_tokens=self._lead_agent_tokens,
subagent_tokens=self._subagent_tokens,
middleware_tokens=self._middleware_tokens,
message_count=self._msg_count,
last_ai_message=self._last_ai_msg,
first_human_message=self._first_human_msg,
)
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None: if kwargs.get("parent_run_id") is not None:
@@ -131,7 +114,6 @@ class RunJournal(BaseCallbackHandler):
logger.debug("on_llm_end: could not extract message from response") logger.debug("on_llm_end: could not extract message from response")
return return
serialized_msg = serialize_lc_object(message)
caller = self._identify_caller(kwargs) caller = self._identify_caller(kwargs)
# Latency # Latency
@@ -142,54 +124,52 @@ class RunJournal(BaseCallbackHandler):
usage = getattr(message, "usage_metadata", None) usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {} usage_dict = dict(usage) if usage else {}
# trace event: llm_end (every LLM call) # Trace event: llm_end (every LLM call)
content = getattr(message, "content", "")
self._put( self._put(
event_type="llm_end", event_type="llm_end",
category="trace", category="trace",
content=getattr(message, "content", "") if isinstance(getattr(message, "content", ""), str) else str(getattr(message, "content", "")), content=content if isinstance(content, str) else str(content),
metadata={ metadata={
"message": serialized_msg, "message": serialize_lc_object(message),
"caller": caller, "caller": caller,
"usage": usage_dict, "usage": usage_dict,
"latency_ms": latency_ms, "latency_ms": latency_ms,
}, },
) )
# message event: ai_message (only lead_agent final replies with content) # Message event: ai_message (only lead_agent final replies — no pending tool_calls)
if caller == "lead_agent": tool_calls = getattr(message, "tool_calls", None) or []
content = getattr(message, "content", "") if caller == "lead_agent" and isinstance(content, str) and content and not tool_calls:
if isinstance(content, str) and content: resp_meta = getattr(message, "response_metadata", None) or {}
tool_calls = getattr(message, "tool_calls", None) or [] model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
tool_calls_summary = [{"name": tc.get("name", ""), "status": "success"} for tc in tool_calls if isinstance(tc, dict)] self._put(
resp_meta = getattr(message, "response_metadata", None) or {} event_type="ai_message",
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None category="message",
self._put( content=content,
event_type="ai_message", metadata={"model_name": model_name},
category="message", )
content=content, self._last_ai_msg = content[:2000]
metadata={ self._msg_count += 1
"model_name": model_name,
"tool_calls": tool_calls_summary,
},
)
self._last_ai_msg = content[:2000]
self._msg_count += 1
# Token accumulation # Token accumulation
input_tk = usage_dict.get("input_tokens", 0) or 0 if self._track_tokens:
output_tk = usage_dict.get("output_tokens", 0) or 0 input_tk = usage_dict.get("input_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0 output_tk = usage_dict.get("output_tokens", 0) or 0
if self._track_tokens and total_tk > 0: total_tk = usage_dict.get("total_tokens", 0) or 0
self._total_input_tokens += input_tk if total_tk == 0:
self._total_output_tokens += output_tk total_tk = input_tk + output_tk
self._total_tokens += total_tk if total_tk > 0:
self._llm_call_count += 1 self._total_input_tokens += input_tk
if caller.startswith("subagent:"): self._total_output_tokens += output_tk
self._subagent_tokens += total_tk self._total_tokens += total_tk
elif caller.startswith("middleware:"): self._llm_call_count += 1
self._middleware_tokens += total_tk if caller.startswith("subagent:"):
else: self._subagent_tokens += total_tk
self._lead_agent_tokens += total_tk elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None) self._llm_start_times.pop(str(run_id), None)
@@ -277,20 +257,23 @@ class RunJournal(BaseCallbackHandler):
self._flush_sync() self._flush_sync()
def _flush_sync(self) -> None: def _flush_sync(self) -> None:
"""Flush buffer to RunEventStore. """Best-effort flush of buffer to RunEventStore.
BaseCallbackHandler methods are synchronous. We schedule the async BaseCallbackHandler methods are synchronous. If an event loop is
put_batch via the current event loop. running we schedule an async ``put_batch``; otherwise the events
stay in the buffer and are flushed later by the async ``flush()``
call in the worker's ``finally`` block.
""" """
if not self._buffer: if not self._buffer:
return return
batch = self._buffer.copy()
self._buffer.clear()
try: try:
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
loop.create_task(self._flush_async(batch))
except RuntimeError: except RuntimeError:
logger.warning("RunJournal: no event loop, dropping %d events", len(batch)) # No event loop — keep events in buffer for later async flush.
return
batch = self._buffer.copy()
self._buffer.clear()
loop.create_task(self._flush_async(batch))
async def _flush_async(self, batch: list[dict]) -> None: async def _flush_async(self, batch: list[dict]) -> None:
try: try:
@@ -302,7 +285,10 @@ class RunJournal(BaseCallbackHandler):
for tag in kwargs.get("tags") or []: for tag in kwargs.get("tags") or []:
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"): if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
return tag return tag
return "unknown" # Default to lead_agent: the main agent graph does not inject
# callback tags, while subagents and middleware explicitly tag
# themselves.
return "lead_agent"
# -- Public methods (called by worker) -- # -- Public methods (called by worker) --
@@ -311,7 +297,7 @@ class RunJournal(BaseCallbackHandler):
self._first_human_msg = content[:2000] if content else None self._first_human_msg = content[:2000] if content else None
async def flush(self) -> None: async def flush(self) -> None:
"""Force flush. Used in cancel/error paths.""" """Force flush remaining buffer. Called in worker's finally block."""
if self._buffer: if self._buffer:
batch = self._buffer.copy() batch = self._buffer.copy()
self._buffer.clear() self._buffer.clear()
@@ -123,7 +123,8 @@ async def run_agent(
runtime = Runtime(context={"thread_id": thread_id}, store=store) runtime = Runtime(context={"thread_id": thread_id}, store=store)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Inject RunJournal as a callback # Inject RunJournal as a LangChain callback handler.
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
if journal is not None: if journal is not None:
config.setdefault("callbacks", []).append(journal) config.setdefault("callbacks", []).append(journal)
@@ -241,13 +242,25 @@ async def run_agent(
) )
finally: finally:
# Flush any buffered journal events # Flush any buffered journal events and persist completion data
if journal is not None: if journal is not None:
try: try:
await journal.flush() await journal.flush()
except Exception: except Exception:
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True) logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
# Persist token usage + convenience fields to RunStore
if run_manager._store is not None:
try:
completion = journal.get_completion_data()
await run_manager._store.update_run_completion(
run_id,
status=record.status.value,
**completion,
)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
await bridge.publish_end(run_id) await bridge.publish_end(run_id)
asyncio.create_task(bridge.cleanup(run_id, delay=60)) asyncio.create_task(bridge.cleanup(run_id, delay=60))
+78
View File
@@ -593,6 +593,84 @@ def test_codex_provider_strips_unsupported_max_tokens(monkeypatch):
assert "max_tokens" not in FakeChatModel.captured_kwargs assert "max_tokens" not in FakeChatModel.captured_kwargs
# ---------------------------------------------------------------------------
# stream_usage injection
# ---------------------------------------------------------------------------
class _FakeWithStreamUsage(FakeChatModel):
"""Fake model that declares stream_usage in model_fields (like BaseChatOpenAI)."""
stream_usage: bool | None = None
def test_stream_usage_injected_for_openai_compatible_model(monkeypatch):
"""Factory should set stream_usage=True for models with stream_usage field."""
cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
_patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)
captured: dict = {}
class CapturingModel(_FakeWithStreamUsage):
def __init__(self, **kwargs):
captured.update(kwargs)
BaseChatModel.__init__(self, **kwargs)
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="deepseek")
assert captured.get("stream_usage") is True
def test_stream_usage_not_injected_for_non_openai_model(monkeypatch):
"""Factory should NOT inject stream_usage for models without the field."""
cfg = _make_app_config([_make_model("claude", use="langchain_anthropic:ChatAnthropic")])
_patch_factory(monkeypatch, cfg)
captured: dict = {}
class CapturingModel(FakeChatModel):
def __init__(self, **kwargs):
captured.update(kwargs)
BaseChatModel.__init__(self, **kwargs)
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
factory_module.create_chat_model(name="claude")
assert "stream_usage" not in captured
def test_stream_usage_not_overridden_when_explicitly_set_in_config(monkeypatch):
"""If config dumps stream_usage=False, factory should respect it."""
cfg = _make_app_config([_make_model("deepseek", use="langchain_deepseek:ChatDeepSeek")])
_patch_factory(monkeypatch, cfg, model_class=_FakeWithStreamUsage)
captured: dict = {}
class CapturingModel(_FakeWithStreamUsage):
def __init__(self, **kwargs):
captured.update(kwargs)
BaseChatModel.__init__(self, **kwargs)
monkeypatch.setattr(factory_module, "resolve_class", lambda path, base: CapturingModel)
# Simulate config having stream_usage explicitly set by patching model_dump
original_get_model_config = cfg.get_model_config
def patched_get_model_config(name):
mc = original_get_model_config(name)
mc.stream_usage = False # type: ignore[attr-defined]
return mc
monkeypatch.setattr(cfg, "get_model_config", patched_get_model_config)
factory_module.create_chat_model(name="deepseek")
assert captured.get("stream_usage") is False
def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch): def test_openai_responses_api_settings_are_passed_to_chatopenai(monkeypatch):
model = ModelConfig( model = ModelConfig(
name="gpt-5-responses", name="gpt-5-responses",
@@ -15,7 +15,6 @@ import pytest
from deerflow.config.database_config import DatabaseConfig from deerflow.config.database_config import DatabaseConfig
from deerflow.runtime.runs.store.memory import MemoryRunStore from deerflow.runtime.runs.store.memory import MemoryRunStore
# -- DatabaseConfig -- # -- DatabaseConfig --
-279
View File
@@ -1,279 +0,0 @@
"""Phase 2-B integration tests.
End-to-end test: simulate a run's complete lifecycle, verify data
is correctly written to both RunStore and RunEventStore.
"""
import asyncio
from uuid import uuid4
import pytest
from deerflow.runtime.events.store.memory import MemoryRunEventStore
from deerflow.runtime.journal import RunJournal
from deerflow.runtime.runs.store.memory import MemoryRunStore
class _FakeMessage:
def __init__(self, content, usage):
self.content = content
self.tool_calls = []
self.response_metadata = {"model_name": "test-model"}
self.usage_metadata = usage
self.id = "test-msg-id"
def model_dump(self):
return {"type": "ai", "content": self.content, "id": self.id, "tool_calls": [], "usage_metadata": self.usage_metadata, "response_metadata": self.response_metadata}
class _FakeGeneration:
def __init__(self, message):
self.message = message
class _FakeLLMResult:
def __init__(self, content, usage):
self.generations = [[_FakeGeneration(_FakeMessage(content, usage))]]
def _make_llm_response(content="Hello", usage=None):
return _FakeLLMResult(content, usage)
class TestRunLifecycle:
@pytest.mark.anyio
async def test_full_run_lifecycle(self):
"""Simulate a complete run lifecycle with RunStore + RunEventStore."""
run_store = MemoryRunStore()
event_store = MemoryRunEventStore()
# 1. Create run
await run_store.put("r1", thread_id="t1", status="pending")
# 2. Write human_message
await event_store.put(
thread_id="t1",
run_id="r1",
event_type="human_message",
category="message",
content="What is AI?",
)
# 3. Simulate RunJournal callback sequence
on_complete_data = {}
def on_complete(**data):
on_complete_data.update(data)
journal = RunJournal("r1", "t1", event_store, on_complete=on_complete, flush_threshold=100)
journal.set_first_human_message("What is AI?")
# chain_start (top-level)
journal.on_chain_start({}, {"messages": ["What is AI?"]}, run_id=uuid4(), parent_run_id=None)
# llm_start + llm_end
llm_run_id = uuid4()
journal.on_llm_start({"name": "gpt-4"}, ["prompt"], run_id=llm_run_id, tags=["lead_agent"])
usage = {"input_tokens": 50, "output_tokens": 100, "total_tokens": 150}
journal.on_llm_end(_make_llm_response("AI is artificial intelligence.", usage=usage), run_id=llm_run_id, tags=["lead_agent"])
# chain_end (triggers on_complete + flush_sync which creates a task)
journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
await journal.flush()
# Let event loop process any pending flush tasks from _flush_sync
await asyncio.sleep(0.05)
# 4. Verify messages
messages = await event_store.list_messages("t1")
assert len(messages) == 2 # human + ai
assert messages[0]["event_type"] == "human_message"
assert messages[1]["event_type"] == "ai_message"
assert messages[1]["content"] == "AI is artificial intelligence."
# 5. Verify events
events = await event_store.list_events("t1", "r1")
event_types = {e["event_type"] for e in events}
assert "run_start" in event_types
assert "llm_start" in event_types
assert "llm_end" in event_types
assert "run_end" in event_types
# 6. Verify on_complete data
assert on_complete_data["total_tokens"] == 150
assert on_complete_data["llm_call_count"] == 1
assert on_complete_data["lead_agent_tokens"] == 150
assert on_complete_data["message_count"] == 1
assert on_complete_data["last_ai_message"] == "AI is artificial intelligence."
assert on_complete_data["first_human_message"] == "What is AI?"
@pytest.mark.anyio
async def test_run_with_tool_calls(self):
"""Simulate a run that uses tools."""
event_store = MemoryRunEventStore()
journal = RunJournal("r1", "t1", event_store, flush_threshold=100)
# tool_start + tool_end
journal.on_tool_start({"name": "web_search"}, '{"query": "AI"}', run_id=uuid4())
journal.on_tool_end("Search results...", run_id=uuid4(), name="web_search")
await journal.flush()
events = await event_store.list_events("t1", "r1")
assert len(events) == 2
assert events[0]["event_type"] == "tool_start"
assert events[1]["event_type"] == "tool_end"
@pytest.mark.anyio
async def test_multi_run_thread(self):
"""Multiple runs on the same thread maintain unified seq ordering."""
event_store = MemoryRunEventStore()
# Run 1
await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")
# Run 2
await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2")
await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2")
messages = await event_store.list_messages("t1")
assert len(messages) == 4
assert [m["seq"] for m in messages] == [1, 2, 3, 4]
assert messages[0]["run_id"] == "r1"
assert messages[2]["run_id"] == "r2"
@pytest.mark.anyio
async def test_runmanager_with_store_backing(self):
"""RunManager persists to RunStore when one is provided."""
from deerflow.runtime.runs.manager import RunManager
run_store = MemoryRunStore()
mgr = RunManager(store=run_store)
record = await mgr.create("t1", assistant_id="lead_agent")
# Verify persisted to store
row = await run_store.get(record.run_id)
assert row is not None
assert row["thread_id"] == "t1"
assert row["status"] == "pending"
# Status update
from deerflow.runtime.runs.schemas import RunStatus
await mgr.set_status(record.run_id, RunStatus.running)
row = await run_store.get(record.run_id)
assert row["status"] == "running"
@pytest.mark.anyio
async def test_runmanager_create_or_reject_persists(self):
"""create_or_reject also persists to store."""
from deerflow.runtime.runs.manager import RunManager
run_store = MemoryRunStore()
mgr = RunManager(store=run_store)
record = await mgr.create_or_reject("t1", "lead_agent", metadata={"key": "val"})
row = await run_store.get(record.run_id)
assert row is not None
assert row["status"] == "pending"
assert row["metadata"] == {"key": "val"}
@pytest.mark.anyio
async def test_follow_up_metadata_in_messages(self):
"""human_message metadata carries follow_up_to_run_id."""
event_store = MemoryRunEventStore()
# Run 1
await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")
# Run 2 (follow-up)
await event_store.put(
thread_id="t1",
run_id="r2",
event_type="human_message",
category="message",
content="Tell me more",
metadata={"follow_up_to_run_id": "r1"},
)
messages = await event_store.list_messages("t1")
assert len(messages) == 3
assert messages[2]["metadata"]["follow_up_to_run_id"] == "r1"
@pytest.mark.anyio
async def test_summarization_in_history(self):
"""summary message appears correctly in message history."""
event_store = MemoryRunEventStore()
await event_store.put(thread_id="t1", run_id="r1", event_type="human_message", category="message", content="Q1")
await event_store.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content="A1")
await event_store.put(thread_id="t1", run_id="r2", event_type="summary", category="message", content="Previous conversation summarized.", metadata={"replaced_count": 2})
await event_store.put(thread_id="t1", run_id="r2", event_type="human_message", category="message", content="Q2")
await event_store.put(thread_id="t1", run_id="r2", event_type="ai_message", category="message", content="A2")
messages = await event_store.list_messages("t1")
assert len(messages) == 5
assert messages[2]["event_type"] == "summary"
assert messages[2]["metadata"]["replaced_count"] == 2
@pytest.mark.anyio
async def test_db_backed_run_lifecycle(self, tmp_path):
"""Full lifecycle with SQLite-backed RunRepository + DbRunEventStore."""
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.persistence.repositories.run_repo import RunRepository
from deerflow.runtime.events.store.db import DbRunEventStore
from deerflow.runtime.runs.manager import RunManager
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
sf = get_session_factory()
run_store = RunRepository(sf)
event_store = DbRunEventStore(sf)
mgr = RunManager(store=run_store)
# Create run
record = await mgr.create("t1", "lead_agent")
run_id = record.run_id
# Write human_message
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content="Hello DB")
# Simulate journal
on_complete_data = {}
journal = RunJournal(run_id, "t1", event_store, on_complete=lambda **d: on_complete_data.update(d), flush_threshold=100)
journal.set_first_human_message("Hello DB")
journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
llm_rid = uuid4()
journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"])
journal.on_llm_end(_make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=llm_rid, tags=["lead_agent"])
journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
await journal.flush()
await asyncio.sleep(0.05)
# Verify run persisted
row = await run_store.get(run_id)
assert row is not None
assert row["status"] == "pending" # RunManager set it, journal doesn't update status
# Update completion
await run_store.update_run_completion(run_id, status="success", **on_complete_data)
row = await run_store.get(run_id)
assert row["status"] == "success"
assert row["total_tokens"] == 15
# Verify messages from DB
messages = await event_store.list_messages("t1")
assert len(messages) == 2
assert messages[0]["event_type"] == "human_message"
assert messages[1]["event_type"] == "ai_message"
# Verify events from DB
events = await event_store.list_events("t1", run_id)
event_types = {e["event_type"] for e in events}
assert "run_start" in event_types
assert "llm_end" in event_types
assert "run_end" in event_types
await close_engine()
+209 -106
View File
@@ -16,22 +16,28 @@ from deerflow.runtime.journal import RunJournal
@pytest.fixture @pytest.fixture
def journal_setup(): def journal_setup():
store = MemoryRunEventStore() store = MemoryRunEventStore()
on_complete_data = {} j = RunJournal("r1", "t1", store, flush_threshold=100)
return j, store
def on_complete(**data):
on_complete_data.update(data)
j = RunJournal("r1", "t1", store, on_complete=on_complete, flush_threshold=100)
return j, store, on_complete_data
def _make_llm_response(content="Hello", usage=None): def _make_llm_response(content="Hello", usage=None, tool_calls=None):
"""Create a mock LLM response with a message.""" """Create a mock LLM response with a message."""
msg = MagicMock() msg = MagicMock()
msg.content = content msg.content = content
msg.tool_calls = [] msg.id = f"msg-{id(msg)}"
msg.tool_calls = tool_calls or []
msg.response_metadata = {"model_name": "test-model"} msg.response_metadata = {"model_name": "test-model"}
msg.usage_metadata = usage msg.usage_metadata = usage
# Provide a real model_dump so serialize_lc_object returns a plain dict
# (needed for DB-backed tests where json.dumps must succeed).
msg.model_dump.return_value = {
"type": "ai",
"content": content,
"id": msg.id,
"tool_calls": tool_calls or [],
"usage_metadata": usage,
"response_metadata": {"model_name": "test-model"},
}
gen = MagicMock() gen = MagicMock()
gen.message = msg gen.message = msg
@@ -44,7 +50,7 @@ def _make_llm_response(content="Hello", usage=None):
class TestLlmCallbacks: class TestLlmCallbacks:
@pytest.mark.anyio @pytest.mark.anyio
async def test_on_llm_end_produces_trace_event(self, journal_setup): async def test_on_llm_end_produces_trace_event(self, journal_setup):
j, store, _ = journal_setup j, store = journal_setup
run_id = uuid4() run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"]) j.on_llm_end(_make_llm_response("Hi"), run_id=run_id, tags=["lead_agent"])
@@ -56,7 +62,7 @@ class TestLlmCallbacks:
@pytest.mark.anyio @pytest.mark.anyio
async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup): async def test_on_llm_end_lead_agent_produces_ai_message(self, journal_setup):
j, store, _ = journal_setup j, store = journal_setup
run_id = uuid4() run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, tags=["lead_agent"]) j.on_llm_end(_make_llm_response("Answer"), run_id=run_id, tags=["lead_agent"])
@@ -66,9 +72,23 @@ class TestLlmCallbacks:
assert messages[0]["event_type"] == "ai_message" assert messages[0]["event_type"] == "ai_message"
assert messages[0]["content"] == "Answer" assert messages[0]["content"] == "Answer"
@pytest.mark.anyio
async def test_on_llm_end_with_tool_calls_no_ai_message(self, journal_setup):
"""LLM response with pending tool_calls should NOT produce ai_message."""
j, store = journal_setup
run_id = uuid4()
j.on_llm_end(
_make_llm_response("Let me search", tool_calls=[{"name": "search"}]),
run_id=run_id,
tags=["lead_agent"],
)
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
@pytest.mark.anyio @pytest.mark.anyio
async def test_on_llm_end_subagent_no_ai_message(self, journal_setup): async def test_on_llm_end_subagent_no_ai_message(self, journal_setup):
j, store, _ = journal_setup j, store = journal_setup
run_id = uuid4() run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"]) j.on_llm_start({}, [], run_id=run_id, tags=["subagent:research"])
j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, tags=["subagent:research"]) j.on_llm_end(_make_llm_response("Sub answer"), run_id=run_id, tags=["subagent:research"])
@@ -78,27 +98,34 @@ class TestLlmCallbacks:
@pytest.mark.anyio @pytest.mark.anyio
async def test_token_accumulation(self, journal_setup): async def test_token_accumulation(self, journal_setup):
j, store, on_complete_data = journal_setup j, store = journal_setup
usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30} usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4(), tags=["lead_agent"])
assert j._total_input_tokens == 30 assert j._total_input_tokens == 30
assert j._total_output_tokens == 15 assert j._total_output_tokens == 15
assert j._total_tokens == 45 assert j._total_tokens == 45
assert j._llm_call_count == 2 assert j._llm_call_count == 2
@pytest.mark.anyio
async def test_total_tokens_computed_from_input_output(self, journal_setup):
"""If total_tokens is 0, it should be computed from input + output."""
j, store = journal_setup
j.on_llm_end(
_make_llm_response("Hi", usage={"input_tokens": 100, "output_tokens": 50, "total_tokens": 0}),
run_id=uuid4(),
tags=["lead_agent"],
)
assert j._total_tokens == 150
assert j._lead_agent_tokens == 150
@pytest.mark.anyio @pytest.mark.anyio
async def test_caller_token_classification(self, journal_setup): async def test_caller_token_classification(self, journal_setup):
j, store, _ = journal_setup j, store = journal_setup
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["subagent:research"])
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), tags=["subagent:research"]) j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), tags=["subagent:research"])
j.on_llm_start({}, [], run_id=uuid4(), tags=["middleware:summarization"])
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), tags=["middleware:summarization"]) j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), tags=["middleware:summarization"])
assert j._lead_agent_tokens == 15 assert j._lead_agent_tokens == 15
assert j._subagent_tokens == 15 assert j._subagent_tokens == 15
@@ -106,15 +133,13 @@ class TestLlmCallbacks:
@pytest.mark.anyio @pytest.mark.anyio
async def test_usage_metadata_none_no_crash(self, journal_setup): async def test_usage_metadata_none_no_crash(self, journal_setup):
j, store, _ = journal_setup j, store = journal_setup
j.on_llm_start({}, [], run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("No usage", usage=None), run_id=uuid4(), tags=["lead_agent"])
# Should not raise
await j.flush() await j.flush()
@pytest.mark.anyio @pytest.mark.anyio
async def test_latency_tracking(self, journal_setup): async def test_latency_tracking(self, journal_setup):
j, store, _ = journal_setup j, store = journal_setup
run_id = uuid4() run_id = uuid4()
j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"]) j.on_llm_start({}, [], run_id=run_id, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"]) j.on_llm_end(_make_llm_response("Fast"), run_id=run_id, tags=["lead_agent"])
@@ -127,16 +152,20 @@ class TestLlmCallbacks:
class TestLifecycleCallbacks: class TestLifecycleCallbacks:
@pytest.mark.anyio @pytest.mark.anyio
async def test_on_chain_end_triggers_on_complete(self, journal_setup): async def test_chain_start_end_produce_lifecycle_events(self, journal_setup):
j, store, on_complete_data = journal_setup j, store = journal_setup
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None) j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) j.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
assert "total_tokens" in on_complete_data await asyncio.sleep(0.05)
assert "message_count" in on_complete_data await j.flush()
events = await store.list_events("t1", "r1")
types = [e["event_type"] for e in events if e["category"] == "lifecycle"]
assert "run_start" in types
assert "run_end" in types
@pytest.mark.anyio @pytest.mark.anyio
async def test_nested_chain_ignored(self, journal_setup): async def test_nested_chain_ignored(self, journal_setup):
j, store, on_complete_data = journal_setup j, store = journal_setup
parent_id = uuid4() parent_id = uuid4()
j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id) j.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=parent_id)
j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id) j.on_chain_end({}, run_id=uuid4(), parent_run_id=parent_id)
@@ -149,7 +178,7 @@ class TestLifecycleCallbacks:
class TestToolCallbacks: class TestToolCallbacks:
@pytest.mark.anyio @pytest.mark.anyio
async def test_tool_start_end_produce_trace(self, journal_setup): async def test_tool_start_end_produce_trace(self, journal_setup):
j, store, _ = journal_setup j, store = journal_setup
j.on_tool_start({"name": "web_search"}, "query", run_id=uuid4()) j.on_tool_start({"name": "web_search"}, "query", run_id=uuid4())
j.on_tool_end("results", run_id=uuid4(), name="web_search") j.on_tool_end("results", run_id=uuid4(), name="web_search")
await j.flush() await j.flush()
@@ -158,11 +187,19 @@ class TestToolCallbacks:
assert "tool_start" in types assert "tool_start" in types
assert "tool_end" in types assert "tool_end" in types
@pytest.mark.anyio
async def test_on_tool_error(self, journal_setup):
j, store = journal_setup
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "tool_error" for e in events)
class TestCustomEvents: class TestCustomEvents:
@pytest.mark.anyio @pytest.mark.anyio
async def test_summarization_event(self, journal_setup): async def test_summarization_event(self, journal_setup):
j, store, _ = journal_setup j, store = journal_setup
j.on_custom_event( j.on_custom_event(
"summarization", "summarization",
{"summary": "Context was summarized.", "replaced_count": 5, "replaced_message_ids": ["a", "b"]}, {"summary": "Context was summarized.", "replaced_count": 5, "replaced_message_ids": ["a", "b"]},
@@ -176,50 +213,76 @@ class TestCustomEvents:
assert len(messages) == 1 assert len(messages) == 1
assert messages[0]["event_type"] == "summary" assert messages[0]["event_type"] == "summary"
@pytest.mark.anyio
async def test_non_summarization_custom_event(self, journal_setup):
j, store = journal_setup
j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4())
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "task_running" for e in events)
class TestBufferFlush: class TestBufferFlush:
@pytest.mark.anyio @pytest.mark.anyio
async def test_flush_threshold(self, journal_setup): async def test_flush_threshold(self, journal_setup):
j, store, _ = journal_setup j, store = journal_setup
j._flush_threshold = 3 j._flush_threshold = 3
j.on_tool_start({"name": "a"}, "x", run_id=uuid4()) j.on_tool_start({"name": "a"}, "x", run_id=uuid4())
j.on_tool_start({"name": "b"}, "x", run_id=uuid4()) j.on_tool_start({"name": "b"}, "x", run_id=uuid4())
# Buffer has 2 events, not yet flushed
assert len(j._buffer) == 2 assert len(j._buffer) == 2
j.on_tool_start({"name": "c"}, "x", run_id=uuid4()) j.on_tool_start({"name": "c"}, "x", run_id=uuid4())
# Buffer should have been flushed (threshold=3 triggers flush)
# Give the async task a chance to complete
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
events = await store.list_events("t1", "r1") events = await store.list_events("t1", "r1")
assert len(events) >= 3 assert len(events) >= 3
@pytest.mark.anyio
async def test_events_retained_when_no_loop(self, journal_setup):
"""Events buffered in a sync (no-loop) context should survive
until the async flush() in the finally block."""
j, store = journal_setup
j._flush_threshold = 1
original = asyncio.get_running_loop
def no_loop():
raise RuntimeError("no running event loop")
asyncio.get_running_loop = no_loop
try:
j._put(event_type="llm_end", category="trace", content="test")
finally:
asyncio.get_running_loop = original
assert len(j._buffer) == 1
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "llm_end" for e in events)
class TestIdentifyCaller: class TestIdentifyCaller:
def test_lead_agent_tag(self, journal_setup): def test_lead_agent_tag(self, journal_setup):
j, _, _ = journal_setup j, _ = journal_setup
assert j._identify_caller({"tags": ["lead_agent"]}) == "lead_agent" assert j._identify_caller({"tags": ["lead_agent"]}) == "lead_agent"
def test_subagent_tag(self, journal_setup): def test_subagent_tag(self, journal_setup):
j, _, _ = journal_setup j, _ = journal_setup
assert j._identify_caller({"tags": ["subagent:research"]}) == "subagent:research" assert j._identify_caller({"tags": ["subagent:research"]}) == "subagent:research"
def test_middleware_tag(self, journal_setup): def test_middleware_tag(self, journal_setup):
j, _, _ = journal_setup j, _ = journal_setup
assert j._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization" assert j._identify_caller({"tags": ["middleware:summarization"]}) == "middleware:summarization"
def test_no_tags_returns_unknown(self, journal_setup): def test_no_tags_returns_lead_agent(self, journal_setup):
j, _, _ = journal_setup j, _ = journal_setup
assert j._identify_caller({"tags": []}) == "unknown" assert j._identify_caller({"tags": []}) == "lead_agent"
assert j._identify_caller({}) == "unknown" assert j._identify_caller({}) == "lead_agent"
class TestChainErrorCallback: class TestChainErrorCallback:
@pytest.mark.anyio @pytest.mark.anyio
async def test_on_chain_error_writes_run_error(self, journal_setup): async def test_on_chain_error_writes_run_error(self, journal_setup):
j, store, _ = journal_setup j, store = journal_setup
# parent_run_id must be None (top-level chain) for the event to be recorded
j.on_chain_error(ValueError("boom"), run_id=uuid4(), parent_run_id=None) j.on_chain_error(ValueError("boom"), run_id=uuid4(), parent_run_id=None)
# on_chain_error calls _flush_sync internally, give async task time to complete
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
await j.flush() await j.flush()
events = await store.list_events("t1", "r1") events = await store.list_events("t1", "r1")
@@ -232,85 +295,125 @@ class TestChainErrorCallback:
class TestTokenTrackingDisabled: class TestTokenTrackingDisabled:
@pytest.mark.anyio @pytest.mark.anyio
async def test_track_token_usage_false(self): async def test_track_token_usage_false(self):
"""track_token_usage=False disables token accumulation."""
store = MemoryRunEventStore() store = MemoryRunEventStore()
complete_data = {} j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
j = RunJournal("r1", "t1", store, track_token_usage=False, on_complete=lambda **d: complete_data.update(d), flush_threshold=100) j.on_llm_end(
j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}), run_id=uuid4(), tags=["lead_agent"]) _make_llm_response("X", usage={"input_tokens": 50, "output_tokens": 50, "total_tokens": 100}),
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) run_id=uuid4(),
assert complete_data["total_tokens"] == 0 tags=["lead_agent"],
assert complete_data["llm_call_count"] == 0 )
data = j.get_completion_data()
assert data["total_tokens"] == 0
class TestMiddlewareNoMessage: assert data["llm_call_count"] == 0
@pytest.mark.anyio
async def test_on_llm_end_middleware_no_ai_message(self, journal_setup):
j, store, _ = journal_setup
j.on_llm_end(_make_llm_response("Summary"), run_id=uuid4(), tags=["middleware:summarization"])
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
class TestUnknownCallerTokens:
@pytest.mark.anyio
async def test_unknown_caller_tokens_go_to_lead(self, journal_setup):
"""No caller tag: tokens attributed to lead_agent bucket."""
j, store, _ = journal_setup
j.on_llm_end(_make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), run_id=uuid4(), tags=[])
assert j._lead_agent_tokens == 15
class TestConvenienceFields: class TestConvenienceFields:
@pytest.mark.anyio @pytest.mark.anyio
async def test_last_ai_message_tracks_latest(self, journal_setup): async def test_last_ai_message_tracks_latest(self, journal_setup):
j, store, complete_data = journal_setup j, store = journal_setup
j.on_llm_end(_make_llm_response("First"), run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("First"), run_id=uuid4(), tags=["lead_agent"])
j.on_llm_end(_make_llm_response("Second"), run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("Second"), run_id=uuid4(), tags=["lead_agent"])
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) data = j.get_completion_data()
assert complete_data["last_ai_message"] == "Second" assert data["last_ai_message"] == "Second"
assert complete_data["message_count"] == 2 assert data["message_count"] == 2
@pytest.mark.anyio @pytest.mark.anyio
async def test_first_human_message_via_set(self, journal_setup): async def test_first_human_message_via_set(self, journal_setup):
j, store, complete_data = journal_setup j, _ = journal_setup
j.set_first_human_message("What is AI?") j.set_first_human_message("What is AI?")
j.on_chain_end({}, run_id=uuid4(), parent_run_id=None) data = j.get_completion_data()
assert complete_data["first_human_message"] == "What is AI?" assert data["first_human_message"] == "What is AI?"
class TestToolError:
@pytest.mark.anyio
async def test_on_tool_error(self, journal_setup):
j, store, _ = journal_setup
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch")
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "tool_error" for e in events)
class TestOtherCustomEvent:
@pytest.mark.anyio
async def test_non_summarization_custom_event(self, journal_setup):
j, store, _ = journal_setup
j.on_custom_event("task_running", {"task_id": "t1", "status": "running"}, run_id=uuid4())
await j.flush()
events = await store.list_events("t1", "r1")
assert any(e["event_type"] == "task_running" for e in events)
class TestPublicMethods:
@pytest.mark.anyio
async def test_set_first_human_message(self, journal_setup):
j, _, _ = journal_setup
j.set_first_human_message("Hello world")
assert j._first_human_msg == "Hello world"
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_completion_data(self, journal_setup): async def test_get_completion_data(self, journal_setup):
j, _, _ = journal_setup j, _ = journal_setup
j._total_tokens = 100 j._total_tokens = 100
j._msg_count = 5 j._msg_count = 5
data = j.get_completion_data() data = j.get_completion_data()
assert data["total_tokens"] == 100 assert data["total_tokens"] == 100
assert data["message_count"] == 5 assert data["message_count"] == 5
class TestUnknownCallerTokens:
@pytest.mark.anyio
async def test_unknown_caller_tokens_go_to_lead(self, journal_setup):
j, store = journal_setup
j.on_llm_end(
_make_llm_response("X", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
tags=[],
)
assert j._lead_agent_tokens == 15
# ---------------------------------------------------------------------------
# SQLite-backed end-to-end test
# ---------------------------------------------------------------------------
class TestDbBackedLifecycle:
@pytest.mark.anyio
async def test_full_lifecycle_with_sqlite(self, tmp_path):
"""Full lifecycle with SQLite-backed RunRepository + DbRunEventStore."""
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
from deerflow.persistence.repositories.run_repo import RunRepository
from deerflow.runtime.events.store.db import DbRunEventStore
from deerflow.runtime.runs.manager import RunManager
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
sf = get_session_factory()
run_store = RunRepository(sf)
event_store = DbRunEventStore(sf)
mgr = RunManager(store=run_store)
# Create run
record = await mgr.create("t1", "lead_agent")
run_id = record.run_id
# Write human_message
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content="Hello DB")
# Simulate journal
journal = RunJournal(run_id, "t1", event_store, flush_threshold=100)
journal.set_first_human_message("Hello DB")
journal.on_chain_start({}, {}, run_id=uuid4(), parent_run_id=None)
llm_rid = uuid4()
journal.on_llm_start({"name": "test"}, [], run_id=llm_rid, tags=["lead_agent"])
journal.on_llm_end(
_make_llm_response("DB response", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=llm_rid,
tags=["lead_agent"],
)
journal.on_chain_end({}, run_id=uuid4(), parent_run_id=None)
await asyncio.sleep(0.05)
await journal.flush()
# Verify run persisted
row = await run_store.get(run_id)
assert row is not None
assert row["status"] == "pending"
# Update completion
completion = journal.get_completion_data()
await run_store.update_run_completion(run_id, status="success", **completion)
row = await run_store.get(run_id)
assert row["status"] == "success"
assert row["total_tokens"] == 15
# Verify messages from DB
messages = await event_store.list_messages("t1")
assert len(messages) == 2
assert messages[0]["event_type"] == "human_message"
assert messages[1]["event_type"] == "ai_message"
# Verify events from DB
events = await event_store.list_events("t1", run_id)
event_types = {e["event_type"] for e in events}
assert "run_start" in event_types
assert "llm_end" in event_types
assert "run_end" in event_types
await close_engine()