refactor(journal): fix flush, token tracking, and consolidate tests

RunJournal fixes:
- _flush_sync: retain events in buffer when no event loop instead of
  dropping them; worker's finally block flushes via async flush().
- on_llm_end: add tool_calls filter and caller=="lead_agent" guard for
  ai_message events; mark message IDs for dedup with record_llm_usage.
- worker.py: persist completion data (tokens, message count) to RunStore
  in finally block.

Model factory:
- Auto-inject stream_usage=True for BaseChatOpenAI subclasses with
  custom api_base, so usage_metadata is populated in streaming responses.

Test consolidation:
- Delete test_phase2b_integration.py (redundant with existing tests).
- Move DB-backed lifecycle test into test_run_journal.py.
- Add tests for stream_usage injection in test_model_factory.py.
- Clean up executor/task_tool dead journal references.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-03 17:26:11 +08:00
parent e5b01d7e74
commit b92ddafd4b
7 changed files with 360 additions and 451 deletions
@@ -77,6 +77,15 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
elif "reasoning_effort" not in model_settings_from_config:
model_settings_from_config["reasoning_effort"] = "medium"
# Ensure stream_usage is enabled so that token usage metadata is available
# in streaming responses. LangChain's BaseChatOpenAI only defaults
# stream_usage=True when no custom base_url/api_base is set, so models
# hitting third-party endpoints (e.g. doubao, deepseek) silently lose
# usage data. We default it to True unless explicitly configured.
if "stream_usage" not in model_settings_from_config and "stream_usage" not in kwargs:
if "stream_usage" in getattr(model_class, "model_fields", {}):
model_settings_from_config["stream_usage"] = True
model_instance = model_class(**kwargs, **model_settings_from_config)
if is_tracing_enabled():
@@ -16,7 +16,6 @@ from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Callable
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any
from uuid import UUID
@@ -39,7 +38,6 @@ class RunJournal(BaseCallbackHandler):
event_store: RunEventStore,
*,
track_token_usage: bool = True,
on_complete: Callable[..., Any] | None = None,
flush_threshold: int = 20,
):
super().__init__()
@@ -47,7 +45,6 @@ class RunJournal(BaseCallbackHandler):
self.thread_id = thread_id
self._store = event_store
self._track_tokens = track_token_usage
self._on_complete = on_complete
self._flush_threshold = flush_threshold
# Write buffer
@@ -73,7 +70,6 @@ class RunJournal(BaseCallbackHandler):
# -- Lifecycle callbacks --
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
# Only record for the top-level chain (parent_run_id is None)
if kwargs.get("parent_run_id") is not None:
return
self._put(
@@ -87,19 +83,6 @@ class RunJournal(BaseCallbackHandler):
return
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
self._flush_sync()
if self._on_complete:
self._on_complete(
total_input_tokens=self._total_input_tokens,
total_output_tokens=self._total_output_tokens,
total_tokens=self._total_tokens,
llm_call_count=self._llm_call_count,
lead_agent_tokens=self._lead_agent_tokens,
subagent_tokens=self._subagent_tokens,
middleware_tokens=self._middleware_tokens,
message_count=self._msg_count,
last_ai_message=self._last_ai_msg,
first_human_message=self._first_human_msg,
)
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
if kwargs.get("parent_run_id") is not None:
@@ -131,7 +114,6 @@ class RunJournal(BaseCallbackHandler):
logger.debug("on_llm_end: could not extract message from response")
return
serialized_msg = serialize_lc_object(message)
caller = self._identify_caller(kwargs)
# Latency
@@ -142,54 +124,52 @@ class RunJournal(BaseCallbackHandler):
usage = getattr(message, "usage_metadata", None)
usage_dict = dict(usage) if usage else {}
# trace event: llm_end (every LLM call)
# Trace event: llm_end (every LLM call)
content = getattr(message, "content", "")
self._put(
event_type="llm_end",
category="trace",
content=getattr(message, "content", "") if isinstance(getattr(message, "content", ""), str) else str(getattr(message, "content", "")),
content=content if isinstance(content, str) else str(content),
metadata={
"message": serialized_msg,
"message": serialize_lc_object(message),
"caller": caller,
"usage": usage_dict,
"latency_ms": latency_ms,
},
)
# message event: ai_message (only lead_agent final replies with content)
if caller == "lead_agent":
content = getattr(message, "content", "")
if isinstance(content, str) and content:
tool_calls = getattr(message, "tool_calls", None) or []
tool_calls_summary = [{"name": tc.get("name", ""), "status": "success"} for tc in tool_calls if isinstance(tc, dict)]
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
self._put(
event_type="ai_message",
category="message",
content=content,
metadata={
"model_name": model_name,
"tool_calls": tool_calls_summary,
},
)
self._last_ai_msg = content[:2000]
self._msg_count += 1
# Message event: ai_message (only lead_agent final replies — no pending tool_calls)
tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent" and isinstance(content, str) and content and not tool_calls:
resp_meta = getattr(message, "response_metadata", None) or {}
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
self._put(
event_type="ai_message",
category="message",
content=content,
metadata={"model_name": model_name},
)
self._last_ai_msg = content[:2000]
self._msg_count += 1
# Token accumulation
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if self._track_tokens and total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0:
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None)
@@ -277,20 +257,23 @@ class RunJournal(BaseCallbackHandler):
self._flush_sync()
def _flush_sync(self) -> None:
"""Flush buffer to RunEventStore.
"""Best-effort flush of buffer to RunEventStore.
BaseCallbackHandler methods are synchronous. We schedule the async
put_batch via the current event loop.
BaseCallbackHandler methods are synchronous. If an event loop is
running we schedule an async ``put_batch``; otherwise the events
stay in the buffer and are flushed later by the async ``flush()``
call in the worker's ``finally`` block.
"""
if not self._buffer:
return
batch = self._buffer.copy()
self._buffer.clear()
try:
loop = asyncio.get_running_loop()
loop.create_task(self._flush_async(batch))
except RuntimeError:
logger.warning("RunJournal: no event loop, dropping %d events", len(batch))
# No event loop — keep events in buffer for later async flush.
return
batch = self._buffer.copy()
self._buffer.clear()
loop.create_task(self._flush_async(batch))
async def _flush_async(self, batch: list[dict]) -> None:
try:
@@ -302,7 +285,10 @@ class RunJournal(BaseCallbackHandler):
for tag in kwargs.get("tags") or []:
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
return tag
return "unknown"
# Default to lead_agent: the main agent graph does not inject
# callback tags, while subagents and middleware explicitly tag
# themselves.
return "lead_agent"
# -- Public methods (called by worker) --
@@ -311,7 +297,7 @@ class RunJournal(BaseCallbackHandler):
self._first_human_msg = content[:2000] if content else None
async def flush(self) -> None:
"""Force flush. Used in cancel/error paths."""
"""Force flush remaining buffer. Called in worker's finally block."""
if self._buffer:
batch = self._buffer.copy()
self._buffer.clear()
@@ -123,7 +123,8 @@ async def run_agent(
runtime = Runtime(context={"thread_id": thread_id}, store=store)
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
# Inject RunJournal as a callback
# Inject RunJournal as a LangChain callback handler.
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
if journal is not None:
config.setdefault("callbacks", []).append(journal)
@@ -241,13 +242,25 @@ async def run_agent(
)
finally:
# Flush any buffered journal events
# Flush any buffered journal events and persist completion data
if journal is not None:
try:
await journal.flush()
except Exception:
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
# Persist token usage + convenience fields to RunStore
if run_manager._store is not None:
try:
completion = journal.get_completion_data()
await run_manager._store.update_run_completion(
run_id,
status=record.status.value,
**completion,
)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
await bridge.publish_end(run_id)
asyncio.create_task(bridge.cleanup(run_id, delay=60))