fix(runtime): persist run message summaries (#2850)
* fix(runtime): persist run message summaries (#2849) * fix(runtime): dedupe run message summaries
This commit is contained in:
@@ -20,12 +20,13 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||
from langgraph.types import Command
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -71,6 +72,7 @@ class RunJournal(BaseCallbackHandler):
|
||||
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id
|
||||
self._counted_llm_run_ids: set[str] = set()
|
||||
self._counted_external_source_ids: set[str] = set()
|
||||
self._counted_message_llm_run_ids: set[str] = set()
|
||||
|
||||
# Convenience fields
|
||||
self._last_ai_msg: str | None = None
|
||||
@@ -86,6 +88,50 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
# -- Lifecycle callbacks --
|
||||
|
||||
@staticmethod
|
||||
def _message_text(message: BaseMessage) -> str:
|
||||
"""Extract displayable text from a message's mixed content shape."""
|
||||
content = getattr(message, "content", None)
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for block in content:
|
||||
if isinstance(block, str):
|
||||
parts.append(block)
|
||||
elif isinstance(block, Mapping):
|
||||
text = block.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
else:
|
||||
nested = block.get("content")
|
||||
if isinstance(nested, str):
|
||||
parts.append(nested)
|
||||
return "".join(parts)
|
||||
if isinstance(content, Mapping):
|
||||
for key in ("text", "content"):
|
||||
value = content.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
|
||||
text = getattr(message, "text", None)
|
||||
if isinstance(text, str):
|
||||
return text
|
||||
return ""
|
||||
|
||||
def _record_message_summary(self, message: BaseMessage, *, caller: str | None = None) -> None:
|
||||
"""Update run-level convenience fields for persisted run rows."""
|
||||
self._msg_count += 1
|
||||
|
||||
# ``last_ai_message`` should represent the lead agent's user-facing
|
||||
# answer. Middleware/subagent model calls and empty tool-call-only
|
||||
# AI messages must not overwrite the last useful assistant text.
|
||||
is_ai_message = isinstance(message, AIMessage) or getattr(message, "type", None) == "ai"
|
||||
if is_ai_message and (caller is None or caller == "lead_agent"):
|
||||
text = self._message_text(message).strip()
|
||||
if text:
|
||||
self._last_ai_msg = text[:2000]
|
||||
|
||||
def on_chain_start(
|
||||
self,
|
||||
serialized: dict[str, Any],
|
||||
@@ -164,6 +210,7 @@ class RunJournal(BaseCallbackHandler):
|
||||
content=m.model_dump(),
|
||||
metadata={"caller": caller},
|
||||
)
|
||||
self._record_message_summary(m, caller=caller)
|
||||
break
|
||||
if self._first_human_msg:
|
||||
break
|
||||
@@ -222,6 +269,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
"llm_call_index": call_index,
|
||||
},
|
||||
)
|
||||
if rid not in self._counted_message_llm_run_ids:
|
||||
self._record_message_summary(message, caller=caller)
|
||||
|
||||
# Token accumulation (dedup by langchain run_id to avoid double-counting
|
||||
# when the callback fires more than once for the same response)
|
||||
@@ -245,6 +294,9 @@ class RunJournal(BaseCallbackHandler):
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
if messages:
|
||||
self._counted_message_llm_run_ids.add(str(run_id))
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||
@@ -260,12 +312,14 @@ class RunJournal(BaseCallbackHandler):
|
||||
if isinstance(output, ToolMessage):
|
||||
msg = cast(ToolMessage, output)
|
||||
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
||||
self._record_message_summary(msg)
|
||||
elif isinstance(output, Command):
|
||||
cmd = cast(Command, output)
|
||||
messages = cmd.update.get("messages", [])
|
||||
for message in messages:
|
||||
if isinstance(message, BaseMessage):
|
||||
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
||||
self._record_message_summary(message)
|
||||
else:
|
||||
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user