feat(events): align message events with checkpoint format and add middleware tag injection

- Message events (ai_message, ai_tool_call, tool_result, human_message) now use
  BaseMessage.model_dump() format, matching LangGraph checkpoint values.messages
- on_tool_end extracts tool_call_id/name/status from ToolMessage objects
- on_tool_error now emits tool_result message events with error status
- record_middleware uses middleware:{tag} event_type and middleware category
- Summarization custom events use middleware:summarize category
- TitleMiddleware injects middleware:title tag via get_config() inheritance
- SummarizationMiddleware model bound with middleware:summarize tag
- Worker writes human_message using HumanMessage.model_dump()

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-04 20:52:27 +08:00
parent 2d135aad0f
commit 52e7acafee
6 changed files with 356 additions and 98 deletions
@@ -56,13 +56,15 @@ def _create_summarization_middleware() -> SummarizationMiddleware | None:
# Prepare keep parameter # Prepare keep parameter
keep = config.keep.to_tuple() keep = config.keep.to_tuple()
# Prepare model parameter # Prepare model parameter.
# Bind "middleware:summarize" tag so RunJournal identifies these LLM calls
# as middleware rather than lead_agent (SummarizationMiddleware is a
# LangChain built-in, so we tag the model at creation time).
if config.model_name: if config.model_name:
model = create_chat_model(name=config.model_name, thinking_enabled=False) model = create_chat_model(name=config.model_name, thinking_enabled=False)
else: else:
# Use a lightweight model for summarization to save costs
# Falls back to default model if not explicitly specified
model = create_chat_model(thinking_enabled=False) model = create_chat_model(thinking_enabled=False)
model = model.with_config(tags=["middleware:summarize"])
# Prepare kwargs # Prepare kwargs
kwargs = { kwargs = {
@@ -1,10 +1,11 @@
"""Middleware for automatic thread title generation.""" """Middleware for automatic thread title generation."""
import logging import logging
from typing import NotRequired, override from typing import Any, NotRequired, override
from langchain.agents import AgentState from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware from langchain.agents.middleware import AgentMiddleware
from langgraph.config import get_config
from langgraph.runtime import Runtime from langgraph.runtime import Runtime
from deerflow.config.title_config import get_title_config from deerflow.config.title_config import get_title_config
@@ -100,6 +101,20 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
return user_msg[:fallback_chars].rstrip() + "..." return user_msg[:fallback_chars].rstrip() + "..."
return user_msg if user_msg else "New Conversation" return user_msg if user_msg else "New Conversation"
def _get_runnable_config(self) -> dict[str, Any]:
    """Build the config passed to the title model call.

    Copies the RunnableConfig of the enclosing run (when there is one) and
    adds the ``middleware:title`` tag, so RunJournal identifies LLM calls
    from this middleware as ``middleware:title`` instead of ``lead_agent``.

    Returns:
        A new config dict. The parent config and its tag list are copied,
        never mutated.
    """
    try:
        parent = get_config()
    except RuntimeError:
        # get_config() raises RuntimeError when no runnable context is
        # active (e.g. the middleware is driven directly); start from an
        # empty config in that case instead of masking unrelated errors.
        parent = {}

    tags = list(parent.get("tags") or [])
    if "middleware:title" not in tags:
        # Do not stack the tag when the inherited config already has it.
        tags.append("middleware:title")
    return {**parent, "tags": tags}
def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None: def _generate_title_result(self, state: TitleMiddlewareState) -> dict | None:
"""Synchronously generate a title. Returns state update or None.""" """Synchronously generate a title. Returns state update or None."""
if not self._should_generate_title(state): if not self._should_generate_title(state):
@@ -110,7 +125,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False) model = create_chat_model(name=config.model_name, thinking_enabled=False)
try: try:
response = model.invoke(prompt) response = model.invoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content) title = self._parse_title(response.content)
if not title: if not title:
title = self._fallback_title(user_msg) title = self._fallback_title(user_msg)
@@ -130,7 +145,7 @@ class TitleMiddleware(AgentMiddleware[TitleMiddlewareState]):
model = create_chat_model(name=config.model_name, thinking_enabled=False) model = create_chat_model(name=config.model_name, thinking_enabled=False)
try: try:
response = await model.ainvoke(prompt) response = await model.ainvoke(prompt, config=self._get_runnable_config())
title = self._parse_title(response.content) title = self._parse_title(response.content)
if not title: if not title:
title = self._fallback_title(user_msg) title = self._fallback_title(user_msg)
@@ -179,7 +179,8 @@ class RunJournal(BaseCallbackHandler):
}, },
) )
# Message events: only lead_agent gets message-category events # Message events: only lead_agent gets message-category events.
# Content uses message.model_dump() to align with checkpoint format.
tool_calls = getattr(message, "tool_calls", None) or [] tool_calls = getattr(message, "tool_calls", None) or []
if caller == "lead_agent": if caller == "lead_agent":
resp_meta = getattr(message, "response_metadata", None) or {} resp_meta = getattr(message, "response_metadata", None) or {}
@@ -189,7 +190,7 @@ class RunJournal(BaseCallbackHandler):
self._put( self._put(
event_type="ai_tool_call", event_type="ai_tool_call",
category="message", category="message",
content=langchain_to_openai_message(message), content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "tool_calls"}, metadata={"model_name": model_name, "finish_reason": "tool_calls"},
) )
elif isinstance(content, str) and content: elif isinstance(content, str) and content:
@@ -197,10 +198,10 @@ class RunJournal(BaseCallbackHandler):
self._put( self._put(
event_type="ai_message", event_type="ai_message",
category="message", category="message",
content={"role": "assistant", "content": content}, content=message.model_dump(),
metadata={"model_name": model_name, "finish_reason": "stop"}, metadata={"model_name": model_name, "finish_reason": "stop"},
) )
self._last_ai_msg = content[:2000] self._last_ai_msg = content
self._msg_count += 1 self._msg_count += 1
# Token accumulation # Token accumulation
@@ -242,45 +243,87 @@ class RunJournal(BaseCallbackHandler):
}, },
) )
def on_tool_end(self, output: Any, *, run_id: UUID, **kwargs: Any) -> None:
    """Record the completion of a tool call.

    Emits a ``tool_end`` trace event and a ``tool_result`` message event
    whose content is a ``ToolMessage.model_dump()`` dict.

    Args:
        output: Tool result. A ``ToolMessage`` is used as-is; any other
            value is stringified and wrapped in a new ``ToolMessage``.
        run_id: Id of the tool run; keys the cached tool_call_id.
        **kwargs: Callback extras. ``tool_call_id`` and ``name`` are read
            as fallbacks.
    """
    from langchain_core.messages import ToolMessage

    # Always remove the cached id for this run, even when a better source
    # is available below. Popping only as a last resort left the entry in
    # the cache whenever the ToolMessage or kwargs already carried an id.
    cached_id = self._tool_call_ids.pop(str(run_id), None)

    if isinstance(output, ToolMessage):
        # Fields on the ToolMessage take priority over kwargs and the cache.
        tool_call_id = output.tool_call_id or kwargs.get("tool_call_id") or cached_id
        tool_name = output.name or kwargs.get("name", "")
        status = getattr(output, "status", "success") or "success"
        content_str = output.content if isinstance(output.content, str) else str(output.content)
        msg_content = output.model_dump()
        # Store the resolved id; use "" rather than None when nothing was
        # resolved, matching the plain-output branch below.
        msg_content["tool_call_id"] = tool_call_id or ""
    else:
        tool_call_id = kwargs.get("tool_call_id") or cached_id
        tool_name = kwargs.get("name", "")
        status = "success"
        content_str = str(output)
        # Build the same dict shape when the output is not a ToolMessage.
        msg_content = ToolMessage(
            content=content_str,
            tool_call_id=tool_call_id or "",
            name=tool_name,
            status=status,
        ).model_dump()

    # Trace event (always)
    self._put(
        event_type="tool_end",
        category="trace",
        content=content_str,
        metadata={
            "tool_name": tool_name,
            "tool_call_id": tool_call_id,
            "status": status,
        },
    )

    # Message event: tool_result (model_dump format)
    self._put(
        event_type="tool_result",
        category="message",
        content=msg_content,
        metadata={"tool_name": tool_name, "status": status},
    )
def on_tool_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
    """Record a failed tool call.

    Emits a ``tool_error`` trace event and a ``tool_result`` message event
    whose content is a ``ToolMessage.model_dump()`` dict with
    ``status="error"``.

    Args:
        error: The exception raised by the tool; stored as ``str(error)``.
        run_id: Id of the tool run; keys the cached tool_call_id.
        **kwargs: Callback extras. ``tool_call_id`` and ``name`` are read.
    """
    from langchain_core.messages import ToolMessage

    # Always remove the cached id for this run so the cache does not keep
    # entries for tool calls whose id also arrived through kwargs.
    cached_id = self._tool_call_ids.pop(str(run_id), None)
    tool_call_id = kwargs.get("tool_call_id") or cached_id
    tool_name = kwargs.get("name", "")
    error_text = str(error)

    # Trace event
    self._put(
        event_type="tool_error",
        category="trace",
        content=error_text,
        metadata={
            "tool_name": tool_name,
            "tool_call_id": tool_call_id,
        },
    )

    # Message event: tool_result with error status (model_dump format)
    msg_content = ToolMessage(
        content=error_text,
        tool_call_id=tool_call_id or "",
        name=tool_name,
        status="error",
    ).model_dump()
    self._put(
        event_type="tool_result",
        category="message",
        content=msg_content,
        metadata={"tool_name": tool_name, "status": "error"},
    )
# -- Custom event callback -- # -- Custom event callback --
def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None: def on_custom_event(self, name: str, data: Any, *, run_id: UUID, **kwargs: Any) -> None:
@@ -298,8 +341,8 @@ class RunJournal(BaseCallbackHandler):
}, },
) )
self._put( self._put(
event_type="summary", event_type="middleware:summarize",
category="message", category="middleware",
content={"role": "system", "content": data_dict.get("summary", "")}, content={"role": "system", "content": data_dict.get("summary", "")},
metadata={"replaced_count": data_dict.get("replaced_count", 0)}, metadata={"replaced_count": data_dict.get("replaced_count", 0)},
) )
@@ -366,16 +409,24 @@ class RunJournal(BaseCallbackHandler):
"""Record the first human message for convenience fields.""" """Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None self._first_human_msg = content[:2000] if content else None
def record_middleware(self, name: str, hook: str, action: str, changes: dict) -> None: def record_middleware(self, tag: str, *, name: str, hook: str, action: str, changes: dict) -> None:
"""Record a middleware trace event. """Record a middleware state-change event.
Called by middleware implementations when they perform a meaningful Called by middleware implementations when they perform a meaningful
state change (e.g., title generation, summarization, HITL approval). state change (e.g., title generation, summarization, HITL approval).
Pure-observation middleware should not call this. Pure-observation middleware should not call this.
Args:
tag: Short identifier for the middleware (e.g., "title", "summarize",
"guardrail"). Used to form event_type="middleware:{tag}".
name: Full middleware class name.
hook: Lifecycle hook that triggered the action (e.g., "after_model").
action: Specific action performed (e.g., "generate_title").
changes: Dict describing the state changes made.
""" """
self._put( self._put(
event_type="middleware", event_type=f"middleware:{tag}",
category="trace", category="middleware",
content={"name": name, "hook": hook, "action": action, "changes": changes}, content={"name": name, "hook": hook, "action": action, "changes": changes},
) )
@@ -67,9 +67,9 @@ async def run_agent(
track_token_usage=getattr(run_events_config, "track_token_usage", True), track_token_usage=getattr(run_events_config, "track_token_usage", True),
) )
# Write human_message event # Write human_message event (model_dump format, aligned with checkpoint)
user_input = _extract_user_input(graph_input) human_msg = _extract_human_message(graph_input)
if user_input: if human_msg is not None:
msg_metadata = {} msg_metadata = {}
if follow_up_to_run_id: if follow_up_to_run_id:
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
@@ -78,10 +78,11 @@ async def run_agent(
run_id=run_id, run_id=run_id,
event_type="human_message", event_type="human_message",
category="message", category="message",
content={"role": "user", "content": user_input}, content=human_msg.model_dump(),
metadata=msg_metadata or None, metadata=msg_metadata or None,
) )
journal.set_first_human_message(user_input) content = human_msg.content
journal.set_first_human_message(content if isinstance(content, str) else str(content))
# Track whether "events" was requested but skipped # Track whether "events" was requested but skipped
if "events" in requested_modes: if "events" in requested_modes:
@@ -282,21 +283,29 @@ def _lg_mode_to_sse_event(mode: str) -> str:
return mode return mode
def _extract_human_message(graph_input: dict) -> "HumanMessage | None":
    """Extract or construct a HumanMessage from graph_input for event recording.

    Looks at ``graph_input["messages"]`` and uses its last element (or the
    value itself when it is not a list). Returns a LangChain HumanMessage so
    callers can use ``.model_dump()`` for serialization.

    Args:
        graph_input: Graph input mapping; only the ``"messages"`` key is read.

    Returns:
        The existing HumanMessage when the last message already is one, a
        new HumanMessage wrapping the content of a str / message-like object
        / dict, or None when there is no message or its content is empty.
    """
    from langchain_core.messages import HumanMessage

    messages = graph_input.get("messages")
    if not messages:
        return None

    # Take the last message (usually the user's input).
    last = messages[-1] if isinstance(messages, list) else messages
    if isinstance(last, HumanMessage):
        return last

    if isinstance(last, str):
        content = last
    elif hasattr(last, "content"):
        content = last.content
    elif isinstance(last, dict):
        content = last.get("content", "")
    else:
        return None

    # Apply the empty-content check to every branch: an object with empty
    # content previously produced an empty HumanMessage, while an empty str
    # or dict produced None.
    return HumanMessage(content=content) if content else None
def _unpack_stream_item( def _unpack_stream_item(
@@ -146,8 +146,11 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"), lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
) )
from unittest.mock import MagicMock
captured: dict[str, object] = {} captured: dict[str, object] = {}
fake_model = object() fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None): def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
captured["name"] = name captured["name"] = name
@@ -163,3 +166,4 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
assert captured["name"] == "model-masswork" assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False assert captured["thinking_enabled"] is False
assert middleware["model"] is fake_model assert middleware["model"] is fake_model
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
+231 -54
View File
@@ -20,24 +20,32 @@ def journal_setup():
return j, store return j, store
def _make_llm_response(content="Hello", usage=None, tool_calls=None): def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None):
"""Create a mock LLM response with a message.""" """Create a mock LLM response with a message.
model_dump() returns checkpoint-aligned format matching real AIMessage.
"""
msg = MagicMock() msg = MagicMock()
msg.type = "ai" msg.type = "ai"
msg.content = content msg.content = content
msg.id = f"msg-{id(msg)}" msg.id = f"msg-{id(msg)}"
msg.tool_calls = tool_calls or [] msg.tool_calls = tool_calls or []
msg.invalid_tool_calls = []
msg.response_metadata = {"model_name": "test-model"} msg.response_metadata = {"model_name": "test-model"}
msg.usage_metadata = usage msg.usage_metadata = usage
# Provide a real model_dump so serialize_lc_object returns a plain dict msg.additional_kwargs = additional_kwargs or {}
# (needed for DB-backed tests where json.dumps must succeed). msg.name = None
# model_dump returns checkpoint-aligned format
msg.model_dump.return_value = { msg.model_dump.return_value = {
"type": "ai",
"content": content, "content": content,
"additional_kwargs": additional_kwargs or {},
"response_metadata": {"model_name": "test-model"},
"type": "ai",
"name": None,
"id": msg.id, "id": msg.id,
"tool_calls": tool_calls or [], "tool_calls": tool_calls or [],
"invalid_tool_calls": [],
"usage_metadata": usage, "usage_metadata": usage,
"response_metadata": {"model_name": "test-model"},
} }
gen = MagicMock() gen = MagicMock()
@@ -71,7 +79,9 @@ class TestLlmCallbacks:
messages = await store.list_messages("t1") messages = await store.list_messages("t1")
assert len(messages) == 1 assert len(messages) == 1
assert messages[0]["event_type"] == "ai_message" assert messages[0]["event_type"] == "ai_message"
assert messages[0]["content"] == {"role": "assistant", "content": "Answer"} # Content is checkpoint-aligned model_dump format
assert messages[0]["content"]["type"] == "ai"
assert messages[0]["content"]["content"] == "Answer"
@pytest.mark.anyio @pytest.mark.anyio
async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup): async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup):
@@ -211,10 +221,14 @@ class TestCustomEvents:
events = await store.list_events("t1", "r1") events = await store.list_events("t1", "r1")
trace = [e for e in events if e["event_type"] == "summarization"] trace = [e for e in events if e["event_type"] == "summarization"]
assert len(trace) == 1 assert len(trace) == 1
# Summarization goes to middleware category, not message
mw_events = [e for e in events if e["event_type"] == "middleware:summarize"]
assert len(mw_events) == 1
assert mw_events[0]["category"] == "middleware"
assert mw_events[0]["content"] == {"role": "system", "content": "Context was summarized."}
# No message events from summarization
messages = await store.list_messages("t1") messages = await store.list_messages("t1")
assert len(messages) == 1 assert len(messages) == 0
assert messages[0]["event_type"] == "summary"
assert messages[0]["content"] == {"role": "system", "content": "Context was summarized."}
@pytest.mark.anyio @pytest.mark.anyio
async def test_non_summarization_custom_event(self, journal_setup): async def test_non_summarization_custom_event(self, journal_setup):
@@ -375,8 +389,11 @@ class TestDbBackedLifecycle:
record = await mgr.create("t1", "lead_agent") record = await mgr.create("t1", "lead_agent")
run_id = record.run_id run_id = record.run_id
# Write human_message # Write human_message (checkpoint-aligned format)
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content={"role": "user", "content": "Hello DB"}) from langchain_core.messages import HumanMessage
human_msg = HumanMessage(content="Hello DB")
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content=human_msg.model_dump())
# Simulate journal # Simulate journal
journal = RunJournal(run_id, "t1", event_store, flush_threshold=100) journal = RunJournal(run_id, "t1", event_store, flush_threshold=100)
@@ -406,12 +423,14 @@ class TestDbBackedLifecycle:
assert row["status"] == "success" assert row["status"] == "success"
assert row["total_tokens"] == 15 assert row["total_tokens"] == 15
# Verify messages from DB # Verify messages from DB (checkpoint-aligned format)
messages = await event_store.list_messages("t1") messages = await event_store.list_messages("t1")
assert len(messages) == 2 assert len(messages) == 2
assert messages[0]["event_type"] == "human_message" assert messages[0]["event_type"] == "human_message"
assert messages[0]["content"]["type"] == "human"
assert messages[1]["event_type"] == "ai_message" assert messages[1]["event_type"] == "ai_message"
assert messages[1]["content"] == {"role": "assistant", "content": "DB response"} assert messages[1]["content"]["type"] == "ai"
assert messages[1]["content"]["content"] == "DB response"
# Verify events from DB # Verify events from DB
events = await event_store.list_events("t1", run_id) events = await event_store.list_events("t1", run_id)
@@ -560,38 +579,45 @@ class TestDictContent:
await close_engine() await close_engine()
class TestOpenAIHumanMessage: class TestCheckpointAlignedHumanMessage:
@pytest.mark.anyio @pytest.mark.anyio
async def test_human_message_openai_format(self): async def test_human_message_checkpoint_format(self):
"""human_message content uses model_dump() checkpoint format."""
from langchain_core.messages import HumanMessage
store = MemoryRunEventStore() store = MemoryRunEventStore()
human_msg = HumanMessage(content="What is AI?")
await store.put( await store.put(
thread_id="t1", thread_id="t1",
run_id="r1", run_id="r1",
event_type="human_message", event_type="human_message",
category="message", category="message",
content={"role": "user", "content": "What is AI?"}, content=human_msg.model_dump(),
metadata={"message_id": "msg_001"}, metadata={"message_id": "msg_001"},
) )
messages = await store.list_messages("t1") messages = await store.list_messages("t1")
assert len(messages) == 1 assert len(messages) == 1
assert messages[0]["content"] == {"role": "user", "content": "What is AI?"} assert messages[0]["content"]["type"] == "human"
assert messages[0]["content"]["role"] == "user" assert messages[0]["content"]["content"] == "What is AI?"
class TestOpenAIMessageFormat: class TestCheckpointAlignedMessageFormat:
@pytest.mark.anyio @pytest.mark.anyio
async def test_ai_message_openai_format(self, journal_setup): async def test_ai_message_checkpoint_format(self, journal_setup):
"""ai_message content should be OpenAI assistant message dict.""" """ai_message content should be checkpoint-aligned model_dump dict."""
j, store = journal_setup j, store = journal_setup
j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), tags=["lead_agent"]) j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), tags=["lead_agent"])
await j.flush() await j.flush()
messages = await store.list_messages("t1") messages = await store.list_messages("t1")
assert len(messages) == 1 assert len(messages) == 1
assert messages[0]["content"] == {"role": "assistant", "content": "Answer"} assert messages[0]["content"]["type"] == "ai"
assert messages[0]["content"]["content"] == "Answer"
assert "response_metadata" in messages[0]["content"]
assert "additional_kwargs" in messages[0]["content"]
@pytest.mark.anyio @pytest.mark.anyio
async def test_ai_tool_call_event(self, journal_setup): async def test_ai_tool_call_event(self, journal_setup):
"""LLM response with tool_calls should produce ai_tool_call message event.""" """LLM response with tool_calls should produce ai_tool_call with model_dump content."""
j, store = journal_setup j, store = journal_setup
tool_calls = [{"id": "call_1", "name": "search", "args": {"query": "test"}}] tool_calls = [{"id": "call_1", "name": "search", "args": {"query": "test"}}]
j.on_llm_end( j.on_llm_end(
@@ -603,13 +629,12 @@ class TestOpenAIMessageFormat:
messages = await store.list_messages("t1") messages = await store.list_messages("t1")
assert len(messages) == 1 assert len(messages) == 1
assert messages[0]["event_type"] == "ai_tool_call" assert messages[0]["event_type"] == "ai_tool_call"
assert messages[0]["content"]["role"] == "assistant" assert messages[0]["content"]["type"] == "ai"
assert messages[0]["content"]["content"] == "Let me search" assert messages[0]["content"]["content"] == "Let me search"
assert len(messages[0]["content"]["tool_calls"]) == 1 assert len(messages[0]["content"]["tool_calls"]) == 1
tc = messages[0]["content"]["tool_calls"][0] tc = messages[0]["content"]["tool_calls"][0]
assert tc["id"] == "call_1" assert tc["id"] == "call_1"
assert tc["type"] == "function" assert tc["name"] == "search"
assert tc["function"]["name"] == "search"
@pytest.mark.anyio @pytest.mark.anyio
async def test_ai_tool_call_only_from_lead_agent(self, journal_setup): async def test_ai_tool_call_only_from_lead_agent(self, journal_setup):
@@ -637,11 +662,11 @@ class TestToolResultMessage:
messages = await store.list_messages("t1") messages = await store.list_messages("t1")
assert len(messages) == 1 assert len(messages) == 1
assert messages[0]["event_type"] == "tool_result" assert messages[0]["event_type"] == "tool_result"
assert messages[0]["content"] == { # Content is checkpoint-aligned model_dump format
"role": "tool", assert messages[0]["content"]["type"] == "tool"
"tool_call_id": "call_abc", assert messages[0]["content"]["tool_call_id"] == "call_abc"
"content": "search results here", assert messages[0]["content"]["content"] == "search results here"
} assert messages[0]["content"]["name"] == "web_search"
@pytest.mark.anyio @pytest.mark.anyio
async def test_tool_result_missing_tool_call_id(self, journal_setup): async def test_tool_result_missing_tool_call_id(self, journal_setup):
@@ -652,15 +677,128 @@ class TestToolResultMessage:
await j.flush() await j.flush()
messages = await store.list_messages("t1") messages = await store.list_messages("t1")
assert len(messages) == 1 assert len(messages) == 1
assert messages[0]["content"]["role"] == "tool" assert messages[0]["content"]["type"] == "tool"
@pytest.mark.anyio @pytest.mark.anyio
async def test_tool_error_no_tool_result_message(self, journal_setup): async def test_tool_end_extracts_from_tool_message_object(self, journal_setup):
"""When LangChain passes a ToolMessage object as output, extract fields from it."""
from langchain_core.messages import ToolMessage
j, store = journal_setup
run_id = uuid4()
tool_msg = ToolMessage(
content="search results",
tool_call_id="call_from_obj",
name="web_search",
status="success",
)
j.on_tool_end(tool_msg, run_id=run_id)
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"]["type"] == "tool"
assert messages[0]["content"]["tool_call_id"] == "call_from_obj"
assert messages[0]["content"]["content"] == "search results"
assert messages[0]["content"]["name"] == "web_search"
assert messages[0]["metadata"]["tool_name"] == "web_search"
assert messages[0]["metadata"]["status"] == "success"
events = await store.list_events("t1", "r1")
tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
assert tool_end["metadata"]["tool_call_id"] == "call_from_obj"
assert tool_end["metadata"]["tool_name"] == "web_search"
@pytest.mark.anyio
async def test_tool_message_object_overrides_kwargs(self, journal_setup):
    """Values carried by the ToolMessage win over conflicting callback kwargs."""
    from langchain_core.messages import ToolMessage

    journal, event_store = journal_setup
    message = ToolMessage(content="result", tool_call_id="call_obj", name="tool_a", status="success")

    # Conflicting kwargs on purpose: the recorded event must follow the ToolMessage.
    journal.on_tool_end(message, run_id=uuid4(), name="tool_b", tool_call_id="call_kwarg")
    await journal.flush()

    recorded = (await event_store.list_messages("t1"))[0]
    assert recorded["content"]["tool_call_id"] == "call_obj"
    assert recorded["content"]["name"] == "tool_a"
    assert recorded["metadata"]["tool_name"] == "tool_a"
@pytest.mark.anyio
async def test_tool_message_error_status(self, journal_setup):
    """A ToolMessage with status='error' carries that status into trace and message events."""
    from langchain_core.messages import ToolMessage

    journal, event_store = journal_setup
    failed = ToolMessage(
        content="something went wrong",
        tool_call_id="call_err",
        name="web_fetch",
        status="error",
    )
    journal.on_tool_end(failed, run_id=uuid4())
    await journal.flush()

    # Trace side: the tool_end event reports the error status.
    all_events = await event_store.list_events("t1", "r1")
    tool_end_events = [evt for evt in all_events if evt["event_type"] == "tool_end"]
    assert tool_end_events[0]["metadata"]["status"] == "error"

    # Message side: both the dumped content and the metadata report it.
    first_message = (await event_store.list_messages("t1"))[0]
    assert first_message["content"]["status"] == "error"
    assert first_message["metadata"]["status"] == "error"
@pytest.mark.anyio
async def test_tool_message_fallback_to_cache(self, journal_setup):
    """An empty tool_call_id on the ToolMessage is replaced by the id seen in on_tool_start."""
    from langchain_core.messages import ToolMessage

    journal, event_store = journal_setup
    tool_run = uuid4()

    # on_tool_start is what records the id for this run.
    journal.on_tool_start({"name": "bash"}, "ls", run_id=tool_run, tool_call_id="call_cached")
    journal.on_tool_end(
        ToolMessage(content="file list", tool_call_id="", name="bash"),
        run_id=tool_run,
    )
    await journal.flush()

    stored = await event_store.list_messages("t1")
    assert stored[0]["content"]["tool_call_id"] == "call_cached"
@pytest.mark.anyio
async def test_tool_error_produces_tool_result_message(self, journal_setup):
j, store = journal_setup j, store = journal_setup
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch", tool_call_id="call_1") j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch", tool_call_id="call_1")
await j.flush() await j.flush()
messages = await store.list_messages("t1") messages = await store.list_messages("t1")
assert len(messages) == 0 assert len(messages) == 1
assert messages[0]["event_type"] == "tool_result"
assert messages[0]["content"]["type"] == "tool"
assert messages[0]["content"]["tool_call_id"] == "call_1"
assert "timeout" in messages[0]["content"]["content"]
assert messages[0]["content"]["status"] == "error"
assert messages[0]["metadata"]["status"] == "error"
@pytest.mark.anyio
async def test_tool_error_uses_cached_tool_call_id(self, journal_setup):
    """on_tool_error without a tool_call_id kwarg reuses the id recorded by on_tool_start."""
    journal, event_store = journal_setup
    tool_run = uuid4()

    journal.on_tool_start({"name": "web_fetch"}, "url", run_id=tool_run, tool_call_id="call_cached")
    # No tool_call_id here: the journal has to resolve it on its own.
    journal.on_tool_error(TimeoutError("timeout"), run_id=tool_run, name="web_fetch")
    await journal.flush()

    stored = await event_store.list_messages("t1")
    assert len(stored) == 1
    assert stored[0]["content"]["tool_call_id"] == "call_cached"
def _make_base_messages(): def _make_base_messages():
@@ -745,11 +883,12 @@ class TestLlmRequestResponse:
assert not any(e["event_type"] == "llm_start" for e in events) assert not any(e["event_type"] == "llm_start" for e in events)
class TestMiddlewareTrace: class TestMiddlewareEvents:
@pytest.mark.anyio @pytest.mark.anyio
async def test_record_middleware(self, journal_setup): async def test_record_middleware_uses_middleware_category(self, journal_setup):
j, store = journal_setup j, store = journal_setup
j.record_middleware( j.record_middleware(
"title",
name="TitleMiddleware", name="TitleMiddleware",
hook="after_model", hook="after_model",
action="generate_title", action="generate_title",
@@ -757,27 +896,60 @@ class TestMiddlewareTrace:
) )
await j.flush() await j.flush()
events = await store.list_events("t1", "r1") events = await store.list_events("t1", "r1")
mw_events = [e for e in events if e["event_type"] == "middleware"] mw_events = [e for e in events if e["event_type"] == "middleware:title"]
assert len(mw_events) == 1 assert len(mw_events) == 1
assert mw_events[0]["category"] == "trace" assert mw_events[0]["category"] == "middleware"
assert mw_events[0]["content"]["name"] == "TitleMiddleware" assert mw_events[0]["content"]["name"] == "TitleMiddleware"
assert mw_events[0]["content"]["hook"] == "after_model" assert mw_events[0]["content"]["hook"] == "after_model"
assert mw_events[0]["content"]["action"] == "generate_title" assert mw_events[0]["content"]["action"] == "generate_title"
assert mw_events[0]["content"]["changes"]["title"] == "Test Title" assert mw_events[0]["content"]["changes"]["title"] == "Test Title"
@pytest.mark.anyio
async def test_middleware_events_not_in_messages(self, journal_setup):
    """Middleware events are stored but must not appear in list_messages().

    Records one ``title`` middleware event, flushes the journal, and then
    checks both halves of the contract: the event is present in the event
    log and absent from the message listing.
    """
    j, store = journal_setup
    j.record_middleware(
        "title",
        name="TitleMiddleware",
        hook="after_model",
        action="generate_title",
        changes={"title": "Test"},
    )
    await j.flush()

    # Guard against a vacuous pass: if nothing had been persisted at all,
    # the empty-messages assertion below would succeed for the wrong reason.
    events = await store.list_events("t1", "r1")
    assert any(e["event_type"] == "middleware:title" for e in events)

    messages = await store.list_messages("t1")
    assert len(messages) == 0
@pytest.mark.anyio
async def test_middleware_tag_variants(self, journal_setup):
    """Each middleware tag yields its own ``middleware:{tag}`` event_type."""
    journal, event_store = journal_setup

    # (tag, middleware name, hook, action) for every event recorded below.
    recorded = (
        ("title", "TitleMiddleware", "after_model", "generate_title"),
        ("guardrail", "GuardrailMiddleware", "before_tool", "deny"),
    )
    for tag, mw_name, hook, action in recorded:
        journal.record_middleware(tag, name=mw_name, hook=hook, action=action, changes={})
    await journal.flush()

    # Collect the distinct event types that ended up in the store.
    seen_types = set()
    for event in await event_store.list_events("t1", "r1"):
        seen_types.add(event["event_type"])

    assert "middleware:title" in seen_types
    assert "middleware:guardrail" in seen_types
class TestFullRunSequence: class TestFullRunSequence:
@pytest.mark.anyio @pytest.mark.anyio
async def test_complete_run_event_sequence(self): async def test_complete_run_event_sequence(self):
"""Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply.""" """Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply.
All message events use checkpoint-aligned model_dump format.
"""
from langchain_core.messages import HumanMessage
store = MemoryRunEventStore() store = MemoryRunEventStore()
j = RunJournal("r1", "t1", store, flush_threshold=100) j = RunJournal("r1", "t1", store, flush_threshold=100)
# 1. Human message (written by worker, not journal) # 1. Human message (written by worker, using model_dump format)
human_msg = HumanMessage(content="Search for quantum computing")
await store.put( await store.put(
thread_id="t1", run_id="r1", thread_id="t1", run_id="r1",
event_type="human_message", category="message", event_type="human_message", category="message",
content={"role": "user", "content": "Search for quantum computing"}, content=human_msg.model_dump(),
) )
j.set_first_human_message("Search for quantum computing") j.set_first_human_message("Search for quantum computing")
@@ -805,7 +977,7 @@ class TestFullRunSequence:
j.on_tool_end("Quantum computing results...", run_id=tool_id, name="web_search", tool_call_id="call_1") j.on_tool_end("Quantum computing results...", run_id=tool_id, name="web_search", tool_call_id="call_1")
# 5. Middleware: title generation # 5. Middleware: title generation
j.record_middleware("TitleMiddleware", "after_model", "generate_title", {"title": "Quantum Computing"}) j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={"title": "Quantum Computing"})
# 6. Second LLM call -> final reply # 6. Second LLM call -> final reply
llm2_id = uuid4() llm2_id = uuid4()
@@ -824,18 +996,19 @@ class TestFullRunSequence:
await asyncio.sleep(0.05) await asyncio.sleep(0.05)
await j.flush() await j.flush()
# Verify message sequence (what gets exported for training) # Verify message sequence
messages = await store.list_messages("t1") messages = await store.list_messages("t1")
msg_types = [m["event_type"] for m in messages] msg_types = [m["event_type"] for m in messages]
assert msg_types == ["human_message", "ai_tool_call", "tool_result", "ai_message"] assert msg_types == ["human_message", "ai_tool_call", "tool_result", "ai_message"]
# Verify message content format # Verify checkpoint-aligned format: all messages use "type" not "role"
assert messages[0]["content"]["role"] == "user" assert messages[0]["content"]["type"] == "human"
assert messages[1]["content"]["role"] == "assistant" assert messages[0]["content"]["content"] == "Search for quantum computing"
assert messages[1]["content"]["type"] == "ai"
assert "tool_calls" in messages[1]["content"] assert "tool_calls" in messages[1]["content"]
assert messages[2]["content"]["role"] == "tool" assert messages[2]["content"]["type"] == "tool"
assert messages[2]["content"]["tool_call_id"] == "call_1" assert messages[2]["content"]["tool_call_id"] == "call_1"
assert messages[3]["content"]["role"] == "assistant" assert messages[3]["content"]["type"] == "ai"
assert messages[3]["content"]["content"] == "Here are the results about quantum computing..." assert messages[3]["content"]["content"] == "Here are the results about quantum computing..."
# Verify trace events # Verify trace events
@@ -845,10 +1018,14 @@ class TestFullRunSequence:
assert "llm_response" in trace_types assert "llm_response" in trace_types
assert "tool_start" in trace_types assert "tool_start" in trace_types
assert "tool_end" in trace_types assert "tool_end" in trace_types
assert "middleware" in trace_types
assert "llm_start" not in trace_types assert "llm_start" not in trace_types
assert "llm_end" not in trace_types assert "llm_end" not in trace_types
# Verify middleware events are in their own category
mw_events = [e for e in events if e["category"] == "middleware"]
assert len(mw_events) == 1
assert mw_events[0]["event_type"] == "middleware:title"
# Verify token accumulation # Verify token accumulation
data = j.get_completion_data() data = j.get_completion_data()
assert data["total_tokens"] == 420 # 120 + 300 assert data["total_tokens"] == 420 # 120 + 300
@@ -857,7 +1034,7 @@ class TestFullRunSequence:
assert data["message_count"] == 1 # only final ai_message counts assert data["message_count"] == 1 # only final ai_message counts
assert data["last_ai_message"] == "Here are the results about quantum computing..." assert data["last_ai_message"] == "Here are the results about quantum computing..."
# Verify training data export is trivial # Verify all message contents are checkpoint-aligned dicts with "type" field
training_messages = [m["content"] for m in messages] for m in messages:
assert all(isinstance(m, dict) for m in training_messages) assert isinstance(m["content"], dict)
assert all("role" in m for m in training_messages) assert "type" in m["content"]