mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
refactor(journal): fix flush, token tracking, and consolidate tests
RunJournal fixes: - _flush_sync: retain events in buffer when no event loop instead of dropping them; worker's finally block flushes via async flush(). - on_llm_end: add tool_calls filter and caller=="lead_agent" guard for ai_message events; mark message IDs for dedup with record_llm_usage. - worker.py: persist completion data (tokens, message count) to RunStore in finally block. Model factory: - Auto-inject stream_usage=True for BaseChatOpenAI subclasses with custom api_base, so usage_metadata is populated in streaming responses. Test consolidation: - Delete test_phase2b_integration.py (redundant with existing tests). - Move DB-backed lifecycle test into test_run_journal.py. - Add tests for stream_usage injection in test_model_factory.py. - Clean up executor/task_tool dead journal references. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -77,6 +77,15 @@ def create_chat_model(name: str | None = None, thinking_enabled: bool = False, *
|
||||
elif "reasoning_effort" not in model_settings_from_config:
|
||||
model_settings_from_config["reasoning_effort"] = "medium"
|
||||
|
||||
# Ensure stream_usage is enabled so that token usage metadata is available
|
||||
# in streaming responses. LangChain's BaseChatOpenAI only defaults
|
||||
# stream_usage=True when no custom base_url/api_base is set, so models
|
||||
# hitting third-party endpoints (e.g. doubao, deepseek) silently lose
|
||||
# usage data. We default it to True unless explicitly configured.
|
||||
if "stream_usage" not in model_settings_from_config and "stream_usage" not in kwargs:
|
||||
if "stream_usage" in getattr(model_class, "model_fields", {}):
|
||||
model_settings_from_config["stream_usage"] = True
|
||||
|
||||
model_instance = model_class(**kwargs, **model_settings_from_config)
|
||||
|
||||
if is_tracing_enabled():
|
||||
|
||||
@@ -16,7 +16,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from uuid import UUID
|
||||
@@ -39,7 +38,6 @@ class RunJournal(BaseCallbackHandler):
|
||||
event_store: RunEventStore,
|
||||
*,
|
||||
track_token_usage: bool = True,
|
||||
on_complete: Callable[..., Any] | None = None,
|
||||
flush_threshold: int = 20,
|
||||
):
|
||||
super().__init__()
|
||||
@@ -47,7 +45,6 @@ class RunJournal(BaseCallbackHandler):
|
||||
self.thread_id = thread_id
|
||||
self._store = event_store
|
||||
self._track_tokens = track_token_usage
|
||||
self._on_complete = on_complete
|
||||
self._flush_threshold = flush_threshold
|
||||
|
||||
# Write buffer
|
||||
@@ -73,7 +70,6 @@ class RunJournal(BaseCallbackHandler):
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
def on_chain_start(self, serialized: dict, inputs: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
# Only record for the top-level chain (parent_run_id is None)
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
return
|
||||
self._put(
|
||||
@@ -87,19 +83,6 @@ class RunJournal(BaseCallbackHandler):
|
||||
return
|
||||
self._put(event_type="run_end", category="lifecycle", metadata={"status": "success"})
|
||||
self._flush_sync()
|
||||
if self._on_complete:
|
||||
self._on_complete(
|
||||
total_input_tokens=self._total_input_tokens,
|
||||
total_output_tokens=self._total_output_tokens,
|
||||
total_tokens=self._total_tokens,
|
||||
llm_call_count=self._llm_call_count,
|
||||
lead_agent_tokens=self._lead_agent_tokens,
|
||||
subagent_tokens=self._subagent_tokens,
|
||||
middleware_tokens=self._middleware_tokens,
|
||||
message_count=self._msg_count,
|
||||
last_ai_message=self._last_ai_msg,
|
||||
first_human_message=self._first_human_msg,
|
||||
)
|
||||
|
||||
def on_chain_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
if kwargs.get("parent_run_id") is not None:
|
||||
@@ -131,7 +114,6 @@ class RunJournal(BaseCallbackHandler):
|
||||
logger.debug("on_llm_end: could not extract message from response")
|
||||
return
|
||||
|
||||
serialized_msg = serialize_lc_object(message)
|
||||
caller = self._identify_caller(kwargs)
|
||||
|
||||
# Latency
|
||||
@@ -142,54 +124,52 @@ class RunJournal(BaseCallbackHandler):
|
||||
usage = getattr(message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
|
||||
# trace event: llm_end (every LLM call)
|
||||
# Trace event: llm_end (every LLM call)
|
||||
content = getattr(message, "content", "")
|
||||
self._put(
|
||||
event_type="llm_end",
|
||||
category="trace",
|
||||
content=getattr(message, "content", "") if isinstance(getattr(message, "content", ""), str) else str(getattr(message, "content", "")),
|
||||
content=content if isinstance(content, str) else str(content),
|
||||
metadata={
|
||||
"message": serialized_msg,
|
||||
"message": serialize_lc_object(message),
|
||||
"caller": caller,
|
||||
"usage": usage_dict,
|
||||
"latency_ms": latency_ms,
|
||||
},
|
||||
)
|
||||
|
||||
# message event: ai_message (only lead_agent final replies with content)
|
||||
if caller == "lead_agent":
|
||||
content = getattr(message, "content", "")
|
||||
if isinstance(content, str) and content:
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
tool_calls_summary = [{"name": tc.get("name", ""), "status": "success"} for tc in tool_calls if isinstance(tc, dict)]
|
||||
resp_meta = getattr(message, "response_metadata", None) or {}
|
||||
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
|
||||
self._put(
|
||||
event_type="ai_message",
|
||||
category="message",
|
||||
content=content,
|
||||
metadata={
|
||||
"model_name": model_name,
|
||||
"tool_calls": tool_calls_summary,
|
||||
},
|
||||
)
|
||||
self._last_ai_msg = content[:2000]
|
||||
self._msg_count += 1
|
||||
# Message event: ai_message (only lead_agent final replies — no pending tool_calls)
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
if caller == "lead_agent" and isinstance(content, str) and content and not tool_calls:
|
||||
resp_meta = getattr(message, "response_metadata", None) or {}
|
||||
model_name = resp_meta.get("model_name") if isinstance(resp_meta, dict) else None
|
||||
self._put(
|
||||
event_type="ai_message",
|
||||
category="message",
|
||||
content=content,
|
||||
metadata={"model_name": model_name},
|
||||
)
|
||||
self._last_ai_msg = content[:2000]
|
||||
self._msg_count += 1
|
||||
|
||||
# Token accumulation
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if self._track_tokens and total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
if caller.startswith("subagent:"):
|
||||
self._subagent_tokens += total_tk
|
||||
elif caller.startswith("middleware:"):
|
||||
self._middleware_tokens += total_tk
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
if self._track_tokens:
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk == 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk > 0:
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
if caller.startswith("subagent:"):
|
||||
self._subagent_tokens += total_tk
|
||||
elif caller.startswith("middleware:"):
|
||||
self._middleware_tokens += total_tk
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
@@ -277,20 +257,23 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._flush_sync()
|
||||
|
||||
def _flush_sync(self) -> None:
|
||||
"""Flush buffer to RunEventStore.
|
||||
"""Best-effort flush of buffer to RunEventStore.
|
||||
|
||||
BaseCallbackHandler methods are synchronous. We schedule the async
|
||||
put_batch via the current event loop.
|
||||
BaseCallbackHandler methods are synchronous. If an event loop is
|
||||
running we schedule an async ``put_batch``; otherwise the events
|
||||
stay in the buffer and are flushed later by the async ``flush()``
|
||||
call in the worker's ``finally`` block.
|
||||
"""
|
||||
if not self._buffer:
|
||||
return
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.create_task(self._flush_async(batch))
|
||||
except RuntimeError:
|
||||
logger.warning("RunJournal: no event loop, dropping %d events", len(batch))
|
||||
# No event loop — keep events in buffer for later async flush.
|
||||
return
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
loop.create_task(self._flush_async(batch))
|
||||
|
||||
async def _flush_async(self, batch: list[dict]) -> None:
|
||||
try:
|
||||
@@ -302,7 +285,10 @@ class RunJournal(BaseCallbackHandler):
|
||||
for tag in kwargs.get("tags") or []:
|
||||
if isinstance(tag, str) and (tag.startswith("subagent:") or tag.startswith("middleware:") or tag == "lead_agent"):
|
||||
return tag
|
||||
return "unknown"
|
||||
# Default to lead_agent: the main agent graph does not inject
|
||||
# callback tags, while subagents and middleware explicitly tag
|
||||
# themselves.
|
||||
return "lead_agent"
|
||||
|
||||
# -- Public methods (called by worker) --
|
||||
|
||||
@@ -311,7 +297,7 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._first_human_msg = content[:2000] if content else None
|
||||
|
||||
async def flush(self) -> None:
|
||||
"""Force flush. Used in cancel/error paths."""
|
||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||
if self._buffer:
|
||||
batch = self._buffer.copy()
|
||||
self._buffer.clear()
|
||||
|
||||
@@ -123,7 +123,8 @@ async def run_agent(
|
||||
runtime = Runtime(context={"thread_id": thread_id}, store=store)
|
||||
config.setdefault("configurable", {})["__pregel_runtime"] = runtime
|
||||
|
||||
# Inject RunJournal as a callback
|
||||
# Inject RunJournal as a LangChain callback handler.
|
||||
# on_llm_end captures token usage; on_chain_start/end captures lifecycle.
|
||||
if journal is not None:
|
||||
config.setdefault("callbacks", []).append(journal)
|
||||
|
||||
@@ -241,13 +242,25 @@ async def run_agent(
|
||||
)
|
||||
|
||||
finally:
|
||||
# Flush any buffered journal events
|
||||
# Flush any buffered journal events and persist completion data
|
||||
if journal is not None:
|
||||
try:
|
||||
await journal.flush()
|
||||
except Exception:
|
||||
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
||||
|
||||
# Persist token usage + convenience fields to RunStore
|
||||
if run_manager._store is not None:
|
||||
try:
|
||||
completion = journal.get_completion_data()
|
||||
await run_manager._store.update_run_completion(
|
||||
run_id,
|
||||
status=record.status.value,
|
||||
**completion,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
||||
|
||||
await bridge.publish_end(run_id)
|
||||
asyncio.create_task(bridge.cleanup(run_id, delay=60))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user