fix(runtime): persist run message summaries (#2850)

* fix(runtime): persist run message summaries (#2849)

* fix(runtime): dedupe run message summaries
This commit is contained in:
Nan Gao
2026-05-11 13:54:00 +02:00
committed by GitHub
parent c3bc6c7cd5
commit 2eb11f97ab
2 changed files with 148 additions and 1 deletions
@@ -20,12 +20,13 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time import time
from collections.abc import Mapping
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from uuid import UUID from uuid import UUID
from langchain_core.callbacks import BaseCallbackHandler from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage, ToolMessage
from langgraph.types import Command from langgraph.types import Command
if TYPE_CHECKING: if TYPE_CHECKING:
@@ -71,6 +72,7 @@ class RunJournal(BaseCallbackHandler):
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id # Dedup: LangChain may fire on_llm_end multiple times for the same run_id
self._counted_llm_run_ids: set[str] = set() self._counted_llm_run_ids: set[str] = set()
self._counted_external_source_ids: set[str] = set() self._counted_external_source_ids: set[str] = set()
self._counted_message_llm_run_ids: set[str] = set()
# Convenience fields # Convenience fields
self._last_ai_msg: str | None = None self._last_ai_msg: str | None = None
@@ -86,6 +88,50 @@ class RunJournal(BaseCallbackHandler):
# -- Lifecycle callbacks -- # -- Lifecycle callbacks --
@staticmethod
def _message_text(message: BaseMessage) -> str:
    """Return the plain text carried by ``message``, or ``""`` if there is none.

    ``message.content`` is inspected by shape:

    * ``str`` -- returned unchanged.
    * ``list`` -- string items are kept as-is; mapping items contribute their
      ``"text"`` value when it is a string, otherwise their ``"content"``
      value when that is a string. Everything else is skipped. The pieces
      are concatenated with no separator.
    * mapping -- the first string found under ``"text"`` then ``"content"``.

    When none of the above yields a result, a string-valued ``text``
    attribute on the message is used as a last resort.
    """
    payload = getattr(message, "content", None)

    if isinstance(payload, str):
        return payload

    if isinstance(payload, list):
        fragments: list[str] = []
        for item in payload:
            if isinstance(item, str):
                fragments.append(item)
                continue
            if not isinstance(item, Mapping):
                continue
            # Prefer the "text" key; consult "content" only when "text" is
            # missing or not a string.
            candidate = item.get("text")
            if not isinstance(candidate, str):
                candidate = item.get("content")
            if isinstance(candidate, str):
                fragments.append(candidate)
        # A list always settles the result, even when nothing was extracted.
        return "".join(fragments)

    if isinstance(payload, Mapping):
        found = next(
            (value for value in map(payload.get, ("text", "content")) if isinstance(value, str)),
            None,
        )
        if found is not None:
            return found

    fallback = getattr(message, "text", None)
    return fallback if isinstance(fallback, str) else ""
def _record_message_summary(self, message: BaseMessage, *, caller: str | None = None) -> None:
    """Fold one message into the run-level summary fields.

    Every call increments ``self._msg_count``. ``self._last_ai_msg`` is only
    replaced when the message is an AI message, ``caller`` is ``None`` or
    ``"lead_agent"``, and the message has non-blank text; the stored value is
    the stripped text cut to its first 2000 characters.
    """
    self._msg_count += 1

    looks_like_ai = isinstance(message, AIMessage) or getattr(message, "type", None) == "ai"
    if not looks_like_ai:
        return

    # ``last_ai_message`` is meant to hold the lead agent's user-facing
    # answer, so model calls attributed to any other caller (middleware,
    # subagents) are not allowed to replace it.
    if caller is not None and caller != "lead_agent":
        return

    # Tool-call-only AI messages carry no text; keep the previous useful
    # assistant text rather than clearing it.
    stripped = self._message_text(message).strip()
    if not stripped:
        return

    self._last_ai_msg = stripped[:2000]
def on_chain_start( def on_chain_start(
self, self,
serialized: dict[str, Any], serialized: dict[str, Any],
@@ -164,6 +210,7 @@ class RunJournal(BaseCallbackHandler):
content=m.model_dump(), content=m.model_dump(),
metadata={"caller": caller}, metadata={"caller": caller},
) )
self._record_message_summary(m, caller=caller)
break break
if self._first_human_msg: if self._first_human_msg:
break break
@@ -222,6 +269,8 @@ class RunJournal(BaseCallbackHandler):
"llm_call_index": call_index, "llm_call_index": call_index,
}, },
) )
if rid not in self._counted_message_llm_run_ids:
self._record_message_summary(message, caller=caller)
# Token accumulation (dedup by langchain run_id to avoid double-counting # Token accumulation (dedup by langchain run_id to avoid double-counting
# when the callback fires more than once for the same response) # when the callback fires more than once for the same response)
@@ -245,6 +294,9 @@ class RunJournal(BaseCallbackHandler):
else: else:
self._lead_agent_tokens += total_tk self._lead_agent_tokens += total_tk
if messages:
self._counted_message_llm_run_ids.add(str(run_id))
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None: def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None) self._llm_start_times.pop(str(run_id), None)
self._put(event_type="llm.error", category="trace", content=str(error)) self._put(event_type="llm.error", category="trace", content=str(error))
@@ -260,12 +312,14 @@ class RunJournal(BaseCallbackHandler):
if isinstance(output, ToolMessage): if isinstance(output, ToolMessage):
msg = cast(ToolMessage, output) msg = cast(ToolMessage, output)
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump()) self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
self._record_message_summary(msg)
elif isinstance(output, Command): elif isinstance(output, Command):
cmd = cast(Command, output) cmd = cast(Command, output)
messages = cmd.update.get("messages", []) messages = cmd.update.get("messages", [])
for message in messages: for message in messages:
if isinstance(message, BaseMessage): if isinstance(message, BaseMessage):
self._put(event_type="llm.tool.result", category="message", content=message.model_dump()) self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
self._record_message_summary(message)
else: else:
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}") logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
else: else:
+93
View File
@@ -339,6 +339,99 @@ class TestConvenienceFields:
data = j.get_completion_data() data = j.get_completion_data()
assert data["first_human_message"] == "What is AI?" assert data["first_human_message"] == "What is AI?"
@pytest.mark.anyio
async def test_completion_data_counts_human_ai_and_tool_messages(self, journal_setup):
    """One human, one AI and one tool message yield a message count of three."""
    from langchain_core.messages import HumanMessage, ToolMessage

    journal, _ = journal_setup
    question = HumanMessage(content="Question")
    tool_output = ToolMessage(content="Tool result", tool_call_id="call_1", name="search")

    journal.on_chat_model_start({}, [[question]], run_id=uuid4(), tags=["lead_agent"])
    journal.on_llm_end(
        _make_llm_response("Answer"),
        run_id=uuid4(),
        parent_run_id=None,
        tags=["lead_agent"],
    )
    journal.on_tool_end(tool_output, run_id=uuid4())

    summary = journal.get_completion_data()
    assert summary["message_count"] == 3
    assert summary["first_human_message"] == "Question"
    assert summary["last_ai_message"] == "Answer"
@pytest.mark.anyio
async def test_tool_call_only_ai_does_not_clear_last_ai_message(self, journal_setup):
    """An AI response with empty text and only tool calls is counted but keeps the earlier answer."""
    journal, _ = journal_setup

    answer = _make_llm_response("Useful answer")
    journal.on_llm_end(answer, run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])

    tool_call_only = _make_llm_response("", tool_calls=[{"id": "call_1", "name": "search", "args": {}}])
    journal.on_llm_end(tool_call_only, run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])

    summary = journal.get_completion_data()
    assert summary["message_count"] == 2
    assert summary["last_ai_message"] == "Useful answer"
@pytest.mark.anyio
async def test_last_ai_message_extracts_mixed_content_without_extra_newlines(self, journal_setup):
    """List content mixing text blocks, a bare string and an image block is joined with no separators."""
    journal, _ = journal_setup

    mixed_blocks = [
        {"type": "text", "text": "First "},
        {"type": "text", "content": "second"},
        " third",
        {"type": "image", "url": "ignored"},
    ]
    response = _make_llm_response(mixed_blocks)
    journal.on_llm_end(response, run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])

    summary = journal.get_completion_data()
    assert summary["message_count"] == 1
    assert summary["last_ai_message"] == "First second third"
@pytest.mark.anyio
async def test_last_ai_message_extracts_mapping_content(self, journal_setup):
    """Mapping-shaped content exposes its text through the ``"content"`` key."""
    journal, _ = journal_setup

    response = _make_llm_response({"content": "Nested answer"})
    journal.on_llm_end(response, run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])

    summary = journal.get_completion_data()
    assert summary["message_count"] == 1
    assert summary["last_ai_message"] == "Nested answer"
@pytest.mark.anyio
async def test_duplicate_llm_run_id_does_not_double_count_message_summary(self, journal_setup):
    """Two ``on_llm_end`` calls sharing a run id count one message; usage from the second call is still recorded."""
    journal, _ = journal_setup
    shared_run_id = uuid4()

    # The first delivery carries no usage data; the repeat delivery does.
    without_usage = _make_llm_response("Answer", usage=None)
    with_usage = _make_llm_response(
        "Answer",
        usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
    )

    journal.on_llm_end(without_usage, run_id=shared_run_id, parent_run_id=None, tags=["lead_agent"])
    journal.on_llm_end(with_usage, run_id=shared_run_id, parent_run_id=None, tags=["lead_agent"])

    summary = journal.get_completion_data()
    assert summary["message_count"] == 1
    assert summary["last_ai_message"] == "Answer"
    assert summary["total_tokens"] == 15
@pytest.mark.anyio
async def test_subagent_ai_does_not_overwrite_lead_last_ai_message(self, journal_setup):
    """A response tagged as a subagent is counted but does not replace the lead agent's last answer."""
    journal, _ = journal_setup

    lead_response = _make_llm_response("Lead answer")
    journal.on_llm_end(lead_response, run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])

    subagent_response = _make_llm_response("Subagent detail")
    journal.on_llm_end(subagent_response, run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])

    summary = journal.get_completion_data()
    assert summary["message_count"] == 2
    assert summary["last_ai_message"] == "Lead answer"
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_completion_data(self, journal_setup): async def test_get_completion_data(self, journal_setup):
j, _ = journal_setup j, _ = journal_setup