feat(events): align message events with checkpoint format and add middleware tag injection

- Message events (ai_message, ai_tool_call, tool_result, human_message) now use
  BaseMessage.model_dump() format, matching LangGraph checkpoint values.messages
- on_tool_end extracts tool_call_id/name/status from ToolMessage objects
- on_tool_error now emits tool_result message events with error status
- record_middleware uses middleware:{tag} event_type and middleware category
- Summarization custom events use middleware:summarize category
- TitleMiddleware injects middleware:title tag via get_config() inheritance
- SummarizationMiddleware model bound with middleware:summarize tag
- Worker writes human_message using HumanMessage.model_dump()

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-04 20:52:27 +08:00
parent 2d135aad0f
commit 52e7acafee
6 changed files with 356 additions and 98 deletions
@@ -56,13 +56,15 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
# Prepare keep parameter
keep = config.keep.to_tuple()
# Prepare model parameter
# Prepare model parameter.
# Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False)
else:
# Use a lightweight model for summarization to save costs
# Falls back to default model if not explicitly specified
model = create_chat_model(thinking_enabled=False)
model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs
kwargs = {
@@ -1,10 +1,11 @@
"""Middleware for automatic thread title generation."""
import logging
from typing import NotRequired, override
from typing import Any, NotRequired, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config
@@ -100,6 +101,20 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation"
def _get_runnable_config(self) -> dict[str, Any]:
"""Inherit the parent RunnableConfig and add middleware tag.
This ensures RunJournal identifies LLM calls from this middleware
as ``middleware:title`` instead of ``lead_agent``.
"""
try:
parent = get_config()
except Exception:
parent = {}
config = {**parent}
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
return config
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Synchronously generate a title. Returns state update or None."""
if not self._should_generate_title(state):
@@ -110,7 +125,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try:
response = model.invoke(prompt)
response = model.invoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
@@ -130,7 +145,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False)
try:
response = await model.ainvoke(prompt)
response = await model.ainvoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content)
if not title:
title = self._fallback_title(user_msg)
@@ -179,7 +179,8 @@ class RunJournal(BaseCallbackHandler):
},
)
# Message events: only lead_agent gets message-category events
# Message events: only lead_agent gets message-category events.
# Content uses message.model_dump() to align with checkpoint format.
tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent":
resp_meta = getattr(message, "response_metadata", None) or {}
@@ -189,7 +190,7 @@ class RunJournal(BaseCallbackHandler):
self._put(
event_type="ai_tool_call",
category="message",
content=langchain_to_openai_message(message),
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
)
elif isinstance(content, str) and content:
@@ -197,10 +198,10 @@ class RunJournal(BaseCallbackHandler):
self._put(
event_type="ai_message",
category="message",
content={"role": "assistant", "content": content},
content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "stop"},
)
self._last_ai_msg = content[:2000]
self._last_ai_msg = content
self._msg_count += 1
# Token accumulation
@@ -242,45 +243,87 @@ class RunJournal(BaseCallbackHandler):
},
)
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
# Extract fields from ToolMessage object when LangChain provides one.
# LangChain's _format_output wraps tool results into a ToolMessage
# with tool_call_id, name, status, and artifact — more complete than
# what kwargs alone provides.
if isinstance(output, ToolMessage):
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = output.name or kwargs.get("name", "")
status = getattr(output, "status", "success") or "success"
content_str = output.content if isinstance(output.content, str) else str(output.content)
# Use model_dump() for checkpoint-aligned message content.
# Override tool_call_id if it was resolved from cache.
msg_content = output.model_dump()
if msg_content.get("tool_call_id") != tool_call_id:
msg_content["tool_call_id"] = tool_call_id
else:
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
status = "success"
content_str = str(output)
# Construct checkpoint-aligned dict when output is a plain string.
msg_content = ToolMessage(
content=content_str,
tool_call_id=tool_call_id or "",
name=tool_name,
status=status,
).model_dump()
# Trace event (always)
self._put(
event_type="tool_end",
category="trace",
content=str(output),
content=content_str,
metadata={
"tool_name": tool_name,
"tool_call_id": tool_call_id,
"status": "success",
"status": status,
},
)
# Message event: tool_result
# Message event: tool_result (checkpoint-aligned model_dump format)
self._put(
event_type="tool_result",
category="message",
content={
"role": "tool",
"tool_call_id": tool_call_id or "",
"content": str(output),
},
metadata={"tool_name": tool_name},
content=msg_content,
metadata={"tool_name": tool_name, "status": status},
)
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
from langchain_core.messages import ToolMessage
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
tool_name = kwargs.get("name", "")
# Trace event
self._put(
event_type="tool_error",
category="trace",
content=str(error),
metadata={
"tool_name": kwargs.get("name", ""),
"tool_call_id": kwargs.get("tool_call_id"),
"tool_name": tool_name,
"tool_call_id": tool_call_id,
},
)
# Message event: tool_result with error status (checkpoint-aligned)
msg_content = ToolMessage(
content=str(error),
tool_call_id=tool_call_id or "",
name=tool_name,
status="error",
).model_dump()
self._put(
event_type="tool_result",
category="message",
content=msg_content,
metadata={"tool_name": tool_name, "status": "error"},
)
# -- Custom event callback --
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
@@ -298,8 +341,8 @@ class RunJournal(BaseCallbackHandler):
},
)
self._put(
event_type="summary",
category="message",
event_type="middleware:summarize",
category="middleware",
content={"role": "system", "content": data_dict.get("summary", "")},
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
)
@@ -366,16 +409,24 @@ class RunJournal(BaseCallbackHandler):
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
def record_middleware(self, name: str, hook: str, action: str, changes: dict) -> None:
"""Record a middleware trace event.
def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
"""Record a middleware state-change event.
Called by middleware implementations when they perform a meaningful
state change (e.g., title generation, summarization, HITL approval).
Pure-observation middleware should not call this.
Args:
tag: Short identifier for the middleware (e.g., "title", "summarize",
"guardrail"). Used to form event_type="middleware:{tag}".
name: Full middleware class name.
hook: Lifecycle hook that triggered the action (e.g., "after_model").
action: Specific action performed (e.g., "generate_title").
changes: Dict describing the state changes made.
"""
self._put(
event_type="middleware",
category="trace",
event_type=f"middleware:{tag}",
category="middleware",
content={"name": name, "hook": hook, "action": action, "changes": changes},
)
@@ -67,9 +67,9 @@ async def run_agent(
track_token_usage=getattr(run_events_config, "track_token_usage", True),
)
# Write human_message event
user_input = _extract_user_input(graph_input)
if user_input:
# Write human_message event (model_dump format, aligned with checkpoint)
human_msg = _extract_human_message(graph_input)
if human_msg is not None:
msg_metadata = {}
if follow_up_to_run_id:
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
@@ -78,10 +78,11 @@ async def run_agent(
run_id=run_id,
event_type="human_message",
category="message",
content={"role": "user", "content": user_input},
content=human_msg.model_dump(),
metadata=msg_metadata or None,
)
journal.set_first_human_message(user_input)
content = human_msg.content
journal.set_first_human_message(content if isinstance(content, str) else str(content))
# Track whether "events" was requested but skipped
if "events" in requested_modes:
@@ -282,21 +283,29 @@ def _lg_mode_to_sse_event(mode: str) -> str:
return mode
def _extract_user_input(graph_input: dict) -> str:
"""Extract user input text from graph_input for event recording."""
def _extract_human_message(graph_input: dict) -> "HumanMessage | None":
"""Extract or construct a HumanMessage from graph_input for event recording.
Returns a LangChain HumanMessage so callers can use .model_dump() to get
the checkpoint-aligned serialization format.
"""
from langchain_core.messages import HumanMessage
messages = graph_input.get("messages")
if not messages:
return ""
# Take the last message (usually the user's input)
return None
last = messages[-1] if isinstance(messages, list) else messages
if isinstance(last, str):
if isinstance(last, HumanMessage):
return last
if isinstance(last, str):
return HumanMessage(content=last) if last else None
if hasattr(last, "content"):
content = last.content
return content if isinstance(content, str) else str(content)
return HumanMessage(content=content)
if isinstance(last, dict):
return str(last.get("content", ""))
return ""
content = last.get("content", "")
return HumanMessage(content=content) if content else None
return None
def _unpack_stream_item(