mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 15:36:48 +00:00
feat(events): align message events with checkpoint format and add middleware tag injection
- Message events (ai_message, ai_tool_call, tool_result, human_message) now use
BaseMessage.model_dump() format, matching LangGraph checkpoint values.messages
- on_tool_end extracts tool_call_id/name/status from ToolMessage objects
- on_tool_error now emits tool_result message events with error status
- record_middleware uses middleware:{tag} event_type and middleware category
- Summarization custom events use middleware:summarize category
- TitleMiddleware injects middleware:title tag via get_config() inheritance
- SummarizationMiddleware model bound with middleware:summarize tag
- Worker writes human_message using HumanMessage.model_dump()
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -56,13 +56,15 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
|
||||
# Prepare keep parameter
|
||||
keep = config.keep.to_tuple()
|
||||
|
||||
# Prepare model parameter
|
||||
# Prepare model parameter.
|
||||
# Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
|
||||
# as middleware rather than lead_agent (SummarizationMiddleware is a
|
||||
# LangChain built-in, so we tag the model at creation time).
|
||||
if config.model_name:
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
else:
|
||||
# Use a lightweight model for summarization to save costs
|
||||
# Falls back to default model if not explicitly specified
|
||||
model = create_chat_model(thinking_enabled=False)
|
||||
model = model.with_config(tags=["middleware:summarize"])
|
||||
|
||||
# Prepare kwargs
|
||||
kwargs = {
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
"""Middleware for automatic thread title generation."""
|
||||
|
||||
import logging
|
||||
from typing import NotRequired, override
|
||||
from typing import Any, NotRequired, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langgraph.config import get_config
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
from deerflow.config.title_config import get_title_config
|
||||
@@ -100,6 +101,20 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
return user_msg[:fallback_chars].rstrip() + "..."
|
||||
return user_msg if user_msg else "New Conversation"
|
||||
|
||||
def _get_runnable_config(self) -> dict[str, Any]:
|
||||
"""Inherit the parent RunnableConfig and add middleware tag.
|
||||
|
||||
This ensures RunJournal identifies LLM calls from this middleware
|
||||
as ``middleware:title`` instead of ``lead_agent``.
|
||||
"""
|
||||
try:
|
||||
parent = get_config()
|
||||
except Exception:
|
||||
parent = {}
|
||||
config = {**parent}
|
||||
config["tags"] = [*(config.get("tags") or []), "middleware:title"]
|
||||
return config
|
||||
|
||||
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
|
||||
"""Synchronously generate a title. Returns state update or None."""
|
||||
if not self._should_generate_title(state):
|
||||
@@ -110,7 +125,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
|
||||
try:
|
||||
response = model.invoke(prompt)
|
||||
response = model.invoke(prompt, config=self._get_runnable_config())
|
||||
title = self._parse_title(response.content)
|
||||
if not title:
|
||||
title = self._fallback_title(user_msg)
|
||||
@@ -130,7 +145,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
|
||||
model = create_chat_model(name=config.model_name, thinking_enabled=False)
|
||||
|
||||
try:
|
||||
response = await model.ainvoke(prompt)
|
||||
response = await model.ainvoke(prompt, config=self._get_runnable_config())
|
||||
title = self._parse_title(response.content)
|
||||
if not title:
|
||||
title = self._fallback_title(user_msg)
|
||||
|
||||
@@ -179,7 +179,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
},
|
||||
)
|
||||
|
||||
# Message events: only lead_agent gets message-category events
|
||||
# Message events: only lead_agent gets message-category events.
|
||||
# Content uses message.model_dump() to align with checkpoint format.
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
if caller == "lead_agent":
|
||||
resp_meta = getattr(message, "response_metadata", None) or {}
|
||||
@@ -189,7 +190,7 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._put(
|
||||
event_type="ai_tool_call",
|
||||
category="message",
|
||||
content=langchain_to_openai_message(message),
|
||||
content=message.model_dump(),
|
||||
metadata={"model_name": model_name, "finish_reason": "tool_calls"},
|
||||
)
|
||||
elif isinstance(content, str) and content:
|
||||
@@ -197,10 +198,10 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._put(
|
||||
event_type="ai_message",
|
||||
category="message",
|
||||
content={"role": "assistant", "content": content},
|
||||
content=message.model_dump(),
|
||||
metadata={"model_name": model_name, "finish_reason": "stop"},
|
||||
)
|
||||
self._last_ai_msg = content[:2000]
|
||||
self._last_ai_msg = content
|
||||
self._msg_count += 1
|
||||
|
||||
# Token accumulation
|
||||
@@ -242,45 +243,87 @@ class RunJournal(BaseCallbackHandler):
|
||||
},
|
||||
)
|
||||
|
||||
def on_tool_end(self, output: str, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = kwargs.get("name", "")
|
||||
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
# Extract fields from ToolMessage object when LangChain provides one.
|
||||
# LangChain's _format_output wraps tool results into a ToolMessage
|
||||
# with tool_call_id, name, status, and artifact — more complete than
|
||||
# what kwargs alone provides.
|
||||
if isinstance(output, ToolMessage):
|
||||
tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = output.name or kwargs.get("name", "")
|
||||
status = getattr(output, "status", "success") or "success"
|
||||
content_str = output.content if isinstance(output.content, str) else str(output.content)
|
||||
# Use model_dump() for checkpoint-aligned message content.
|
||||
# Override tool_call_id if it was resolved from cache.
|
||||
msg_content = output.model_dump()
|
||||
if msg_content.get("tool_call_id") != tool_call_id:
|
||||
msg_content["tool_call_id"] = tool_call_id
|
||||
else:
|
||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = kwargs.get("name", "")
|
||||
status = "success"
|
||||
content_str = str(output)
|
||||
# Construct checkpoint-aligned dict when output is a plain string.
|
||||
msg_content = ToolMessage(
|
||||
content=content_str,
|
||||
tool_call_id=tool_call_id or "",
|
||||
name=tool_name,
|
||||
status=status,
|
||||
).model_dump()
|
||||
|
||||
# Trace event (always)
|
||||
self._put(
|
||||
event_type="tool_end",
|
||||
category="trace",
|
||||
content=str(output),
|
||||
content=content_str,
|
||||
metadata={
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": tool_call_id,
|
||||
"status": "success",
|
||||
"status": status,
|
||||
},
|
||||
)
|
||||
|
||||
# Message event: tool_result
|
||||
# Message event: tool_result (checkpoint-aligned model_dump format)
|
||||
self._put(
|
||||
event_type="tool_result",
|
||||
category="message",
|
||||
content={
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id or "",
|
||||
"content": str(output),
|
||||
},
|
||||
metadata={"tool_name": tool_name},
|
||||
content=msg_content,
|
||||
metadata={"tool_name": tool_name, "status": status},
|
||||
)
|
||||
|
||||
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
tool_call_id = kwargs.get("tool_call_id") or self._tool_call_ids.pop(str(run_id), None)
|
||||
tool_name = kwargs.get("name", "")
|
||||
|
||||
# Trace event
|
||||
self._put(
|
||||
event_type="tool_error",
|
||||
category="trace",
|
||||
content=str(error),
|
||||
metadata={
|
||||
"tool_name": kwargs.get("name", ""),
|
||||
"tool_call_id": kwargs.get("tool_call_id"),
|
||||
"tool_name": tool_name,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
)
|
||||
|
||||
# Message event: tool_result with error status (checkpoint-aligned)
|
||||
msg_content = ToolMessage(
|
||||
content=str(error),
|
||||
tool_call_id=tool_call_id or "",
|
||||
name=tool_name,
|
||||
status="error",
|
||||
).model_dump()
|
||||
self._put(
|
||||
event_type="tool_result",
|
||||
category="message",
|
||||
content=msg_content,
|
||||
metadata={"tool_name": tool_name, "status": "error"},
|
||||
)
|
||||
|
||||
# -- Custom event callback --
|
||||
|
||||
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
@@ -298,8 +341,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
},
|
||||
)
|
||||
self._put(
|
||||
event_type="summary",
|
||||
category="message",
|
||||
event_type="middleware:summarize",
|
||||
category="middleware",
|
||||
content={"role": "system", "content": data_dict.get("summary", "")},
|
||||
metadata={"replaced_count": data_dict.get("replaced_count", 0)},
|
||||
)
|
||||
@@ -366,16 +409,24 @@ class RunJournal(BaseCallbackHandler):
|
||||
"""Record the first human message for convenience fields."""
|
||||
self._first_human_msg = content[:2000] if content else None
|
||||
|
||||
def record_middleware(self, name: str, hook: str, action: str, changes: dict) -> None:
|
||||
"""Record a middleware trace event.
|
||||
def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
|
||||
"""Record a middleware state-change event.
|
||||
|
||||
Called by middleware implementations when they perform a meaningful
|
||||
state change (e.g., title generation, summarization, HITL approval).
|
||||
Pure-observation middleware should not call this.
|
||||
|
||||
Args:
|
||||
tag: Short identifier for the middleware (e.g., "title", "summarize",
|
||||
"guardrail"). Used to form event_type="middleware:{tag}".
|
||||
name: Full middleware class name.
|
||||
hook: Lifecycle hook that triggered the action (e.g., "after_model").
|
||||
action: Specific action performed (e.g., "generate_title").
|
||||
changes: Dict describing the state changes made.
|
||||
"""
|
||||
self._put(
|
||||
event_type="middleware",
|
||||
category="trace",
|
||||
event_type=f"middleware:{tag}",
|
||||
category="middleware",
|
||||
content={"name": name, "hook": hook, "action": action, "changes": changes},
|
||||
)
|
||||
|
||||
|
||||
@@ -67,9 +67,9 @@ async def run_agent(
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
)
|
||||
|
||||
# Write human_message event
|
||||
user_input = _extract_user_input(graph_input)
|
||||
if user_input:
|
||||
# Write human_message event (model_dump format, aligned with checkpoint)
|
||||
human_msg = _extract_human_message(graph_input)
|
||||
if human_msg is not None:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
@@ -78,10 +78,11 @@ async def run_agent(
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content={"role": "user", "content": user_input},
|
||||
content=human_msg.model_dump(),
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
journal.set_first_human_message(user_input)
|
||||
content = human_msg.content
|
||||
journal.set_first_human_message(content if isinstance(content, str) else str(content))
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
@@ -282,21 +283,29 @@ def _lg_mode_to_sse_event(mode: str) -> str:
|
||||
return mode
|
||||
|
||||
|
||||
def _extract_user_input(graph_input: dict) -> str:
|
||||
"""Extract user input text from graph_input for event recording."""
|
||||
def _extract_human_message(graph_input: dict) -> "HumanMessage | None":
|
||||
"""Extract or construct a HumanMessage from graph_input for event recording.
|
||||
|
||||
Returns a LangChain HumanMessage so callers can use .model_dump() to get
|
||||
the checkpoint-aligned serialization format.
|
||||
"""
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
messages = graph_input.get("messages")
|
||||
if not messages:
|
||||
return ""
|
||||
# Take the last message (usually the user's input)
|
||||
return None
|
||||
last = messages[-1] if isinstance(messages, list) else messages
|
||||
if isinstance(last, str):
|
||||
if isinstance(last, HumanMessage):
|
||||
return last
|
||||
if isinstance(last, str):
|
||||
return HumanMessage(content=last) if last else None
|
||||
if hasattr(last, "content"):
|
||||
content = last.content
|
||||
return content if isinstance(content, str) else str(content)
|
||||
return HumanMessage(content=content)
|
||||
if isinstance(last, dict):
|
||||
return str(last.get("content", ""))
|
||||
return ""
|
||||
content = last.get("content", "")
|
||||
return HumanMessage(content=content) if content else None
|
||||
return None
|
||||
|
||||
|
||||
def _unpack_stream_item(
|
||||
|
||||
Reference in New Issue
Block a user