feat(events): align message events with checkpoint format and add middleware tag injection

- Message events (ai_message, ai_tool_call, tool_result, human_message) now use
  BaseMessage.model_dump() format, matching LangGraph checkpoint values.messages
- on_tool_end extracts tool_call_id/name/status from ToolMessage objects
- on_tool_error now emits tool_result message events with error status
- record_middleware uses middleware:{tag} event_type and middleware category
- Summarization custom events use middleware:summarize category
- TitleMiddleware injects middleware:title tag via get_config() inheritance
- SummarizationMiddleware model bound with middleware:summarize tag
- Worker writes human_message using HumanMessage.model_dump()

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-04 20:52:27 +08:00
parent 2d135aad0f
commit 52e7acafee
6 changed files with 356 additions and 98 deletions
@@ -179,7 +179,8 @@ class RunJournal(BaseCallbackHandler):
},
)
# Message events: only lead_agent gets message-category events
# Message events: only lead_agent gets message-category events.
# Content uses message.model_dump() to align with checkpoint format.
tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent":
resp_meta = getattr(message, "response_metadata", None) or {}
@@ -189,7 +190,7 @@ class RunJournal(BaseCallbackHandler):
self._put(
event_type="ai_tool_call",
category="message",
content=langchain_to_openai_message(message),
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
)
elif isinstance(content, str) and content:
@@ -197,10 +198,10 @@ class RunJournal(BaseCallbackHandler):
self._put(
event_type="ai_message",
category="message",
content={"role": "assistant", "content": content},
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "stop"},
)
self._last_ai_msg = content[:2000]
self._last_ai_msg = content
self._msg_count += 1
# Token accumulation
@@ -242,45 +243,87 @@ class RunJournal(BaseCallbackHandler):
},
)
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
# Extract fields from ToolMessage object when LangChain provides one.
# LangChain's _format_output wraps tool results into a ToolMessage
# with tool_call_id, name, status, and artifact — more complete than
# what kwargs alone provides.
if isinstance(output, ToolMessage):
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = output.name or kwargs.get("name", "")
status = getattr(output, "status", "success") or "success"
content_str = output.content if isinstance(output.content, str) else str(output.content)
# Use model_dump() for checkpoint-aligned message content.
# Override tool_call_id if it was resolved from cache.
msg_content = output.model_dump()
if msg_content.get("tool_call_id") != tool_call_id:
msg_content["tool_call_id"] = tool_call_id
else:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
status = "success"
content_str = str(output)
# Construct checkpoint-aligned dict when output is a plain string.
msg_content = ToolMessage(
content=content_str,
tool_call_id=tool_call_id or "",
name=tool_name,
status=status,
).model_dump()
# Trace event (always)
self._put(
event_type="tool_end",
category="trace",
content=str(output),
content=content_str,
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
"status": "success",
"status": status,
},
)
# Message event: tool_result
# Message event: tool_result (checkpoint-aligned model_dump format)
self._put(
event_type="tool_result",
category="message",
content={
"role": "tool",
"tool_call_id": tool_call_id or "",
"content": str(output),
},
metadata={"tool_name": tool_name},
content=msg_content,
metadata={"tool_name": tool_name, "status": status},
)
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
# Trace event
self._put(
event_type="tool_error",
category="trace",
content=str(error),
metadata={
"tool_name": kwargs.get("name", ""),
"tool_call_id": kwargs.get("tool_call_id"),
"tool_name": tool_name,
"tool_call_id": tool_call_id,
},
)
# Message event: tool_result with error status (checkpoint-aligned)
msg_content = ToolMessage(
content=str(error),
tool_call_id=tool_call_id or "",
name=tool_name,
status="error",
).model_dump()
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": "error"},
)
# -- Custom event callback --
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
@@ -298,8 +341,8 @@ class RunJournal(BaseCallbackHandler):
},
)
self._put(
event_type="summary",
category="message",
event_type="middleware:summarize",
category="middleware",
content={"role": "system", "content": data_dict.get("summary", "")},
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
)
@@ -366,16 +409,24 @@ class RunJournal(BaseCallbackHandler):
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
def record_middleware(self, name: str, hook: str, action: str, changes: dict) -> None:
"""Record a middleware trace event.
def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
"""Record a middleware state-change event.
Called by middleware implementations when they perform a meaningful
state change (e.g., title generation, summarization, HITL approval).
Pure-observation middleware should not call this.
Args:
tag: Short identifier for the middleware (e.g., "title", "summarize",
"guardrail"). Used to form event_type="middleware:{tag}".
name: Full middleware class name.
hook: Lifecycle hook that triggered the action (e.g., "after_model").
action: Specific action performed (e.g., "generate_title").
changes: Dict describing the state changes made.
"""
self._put(
event_type="middleware",
category="trace",
event_type=f"middleware:{tag}",
category="middleware",
content={"name": name, "hook": hook, "action": action, "changes": changes},
)
@@ -67,9 +67,9 @@ async def run_agent(
track_token_usage=getattr(run_events_config, "track_token_usage", True),
)
# Write human_message event
user_input = _extract_user_input(graph_input)
if user_input:
# Write human_message event (model_dump format, aligned with checkpoint)
human_msg = _extract_human_message(graph_input)
if human_msg is not None:
msg_metadata = {}
if follow_up_to_run_id:
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
@@ -78,10 +78,11 @@ async def run_agent(
run_id=run_id,
event_type="human_message",
category="message",
content={"role": "user", "content": user_input},
content=human_msg.model_dump(),
metadata=msg_metadata or None,
)
journal.set_first_human_message(user_input)
content = human_msg.content
journal.set_first_human_message(content if isinstance(content, str) else str(content))
# Track whether "events" was requested but skipped
if "events" in requested_modes:
@@ -282,21 +283,29 @@ def _lg_mode_to_sse_event(mode: str) -> str:
return mode
def _extract_user_input(graph_input: dict) -> str:
"""Extract user input text from graph_input for event recording."""
def _extract_human_message(graph_input: dict) -> "HumanMessage | None":
"""Extract or construct a HumanMessage from graph_input for event recording.
Returns a LangChain HumanMessage so callers can use .model_dump() to get
the checkpoint-aligned serialization format.
"""
from langchain_core.messages import HumanMessage
messages = graph_input.get("messages")
if not messages:
return ""
# Take the last message (usually the user's input)
return None
last = messages[-1] if isinstance(messages, list) else messages
if isinstance(last, str):
if isinstance(last, HumanMessage):
return last
if isinstance(last, str):
return HumanMessage(content=last) if last else None
if hasattr(last, "content"):
content = last.content
return content if isinstance(content, str) else str(content)
return HumanMessage(content=content)
if isinstance(last, dict):
return str(last.get("content", ""))
return ""
content = last.get("content", "")
return HumanMessage(content=content) if content else None
return None
def _unpack_stream_item(