feat(events): align message events with checkpoint format and add middleware tag injection

- Message events (ai_message, ai_tool_call, tool_result, human_message) now use
  BaseMessage.model_dump() format, matching LangGraph checkpoint values.messages
- on_tool_end extracts tool_call_id/name/status from ToolMessage objects
- on_tool_error now emits tool_result message events with error status
- record_middleware uses middleware:{tag} event_type and middleware category
- Summarization custom events use middleware:summarize category
- TitleMiddleware injects middleware:title tag via get_config() inheritance
- SummarizationMiddleware model bound with middleware:summarize tag
- Worker writes human_message using HumanMessage.model_dump()

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-04 20:52:27 +08:00
parent 2d135aad0f
commit 52e7acafee
6 changed files with 356 additions and 98 deletions
@@ -146,8 +146,11 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
)
from unittest.mock import MagicMock
captured: dict[str, object] = {}
fake_model = object()
fake_model = MagicMock()
fake_model.with_config.return_value = fake_model
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
captured["name"] = name
@@ -163,3 +166,4 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
assert captured["name"] == "model-masswork"
assert captured["thinking_enabled"] is False
assert middleware["model"] is fake_model
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
+231 -54
View File
@@ -20,24 +20,32 @@ def journal_setup():
return j, store
def _make_llm_response(content="Hello", usage=None, tool_calls=None):
"""Create a mock LLM response with a message."""
def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None):
"""Create a mock LLM response with a message.
model_dump() returns checkpoint-aligned format matching real AIMessage.
"""
msg = MagicMock()
msg.type = "ai"
msg.content = content
msg.id = f"msg-{id(msg)}"
msg.tool_calls = tool_calls or []
msg.invalid_tool_calls = []
msg.response_metadata = {"model_name": "test-model"}
msg.usage_metadata = usage
# Provide a real model_dump so serialize_lc_object returns a plain dict
# (needed for DB-backed tests where json.dumps must succeed).
msg.additional_kwargs = additional_kwargs or {}
msg.name = None
# model_dump returns checkpoint-aligned format
msg.model_dump.return_value = {
"type": "ai",
"content": content,
"additional_kwargs": additional_kwargs or {},
"response_metadata": {"model_name": "test-model"},
"type": "ai",
"name": None,
"id": msg.id,
"tool_calls": tool_calls or [],
"invalid_tool_calls": [],
"usage_metadata": usage,
"response_metadata": {"model_name": "test-model"},
}
gen = MagicMock()
@@ -71,7 +79,9 @@ class TestLlmCallbacks:
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["event_type"] == "ai_message"
assert messages[0]["content"] == {"role": "assistant", "content": "Answer"}
# Content is checkpoint-aligned model_dump format
assert messages[0]["content"]["type"] == "ai"
assert messages[0]["content"]["content"] == "Answer"
@pytest.mark.anyio
async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup):
@@ -211,10 +221,14 @@ class TestCustomEvents:
events = await store.list_events("t1", "r1")
trace = [e for e in events if e["event_type"] == "summarization"]
assert len(trace) == 1
# Summarization goes to middleware category, not message
mw_events = [e for e in events if e["event_type"] == "middleware:summarize"]
assert len(mw_events) == 1
assert mw_events[0]["category"] == "middleware"
assert mw_events[0]["content"] == {"role": "system", "content": "Context was summarized."}
# No message events from summarization
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["event_type"] == "summary"
assert messages[0]["content"] == {"role": "system", "content": "Context was summarized."}
assert len(messages) == 0
@pytest.mark.anyio
async def test_non_summarization_custom_event(self, journal_setup):
@@ -375,8 +389,11 @@ class TestDbBackedLifecycle:
record = await mgr.create("t1", "lead_agent")
run_id = record.run_id
# Write human_message
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content={"role": "user", "content": "Hello DB"})
# Write human_message (checkpoint-aligned format)
from langchain_core.messages import HumanMessage
human_msg = HumanMessage(content="Hello DB")
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content=human_msg.model_dump())
# Simulate journal
journal = RunJournal(run_id, "t1", event_store, flush_threshold=100)
@@ -406,12 +423,14 @@ class TestDbBackedLifecycle:
assert row["status"] == "success"
assert row["total_tokens"] == 15
# Verify messages from DB
# Verify messages from DB (checkpoint-aligned format)
messages = await event_store.list_messages("t1")
assert len(messages) == 2
assert messages[0]["event_type"] == "human_message"
assert messages[0]["content"]["type"] == "human"
assert messages[1]["event_type"] == "ai_message"
assert messages[1]["content"] == {"role": "assistant", "content": "DB response"}
assert messages[1]["content"]["type"] == "ai"
assert messages[1]["content"]["content"] == "DB response"
# Verify events from DB
events = await event_store.list_events("t1", run_id)
@@ -560,38 +579,45 @@ class TestDictContent:
await close_engine()
class TestOpenAIHumanMessage:
class TestCheckpointAlignedHumanMessage:
@pytest.mark.anyio
async def test_human_message_openai_format(self):
async def test_human_message_checkpoint_format(self):
"""human_message content uses model_dump() checkpoint format."""
from langchain_core.messages import HumanMessage
store = MemoryRunEventStore()
human_msg = HumanMessage(content="What is AI?")
await store.put(
thread_id="t1",
run_id="r1",
event_type="human_message",
category="message",
content={"role": "user", "content": "What is AI?"},
content=human_msg.model_dump(),
metadata={"message_id": "msg_001"},
)
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"] == {"role": "user", "content": "What is AI?"}
assert messages[0]["content"]["role"] == "user"
assert messages[0]["content"]["type"] == "human"
assert messages[0]["content"]["content"] == "What is AI?"
class TestOpenAIMessageFormat:
class TestCheckpointAlignedMessageFormat:
@pytest.mark.anyio
async def test_ai_message_openai_format(self, journal_setup):
"""ai_message content should be OpenAI assistant message dict."""
async def test_ai_message_checkpoint_format(self, journal_setup):
"""ai_message content should be checkpoint-aligned model_dump dict."""
j, store = journal_setup
j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), tags=["lead_agent"])
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"] == {"role": "assistant", "content": "Answer"}
assert messages[0]["content"]["type"] == "ai"
assert messages[0]["content"]["content"] == "Answer"
assert "response_metadata" in messages[0]["content"]
assert "additional_kwargs" in messages[0]["content"]
@pytest.mark.anyio
async def test_ai_tool_call_event(self, journal_setup):
"""LLM response with tool_calls should produce ai_tool_call message event."""
"""LLM response with tool_calls should produce ai_tool_call with model_dump content."""
j, store = journal_setup
tool_calls = [{"id": "call_1", "name": "search", "args": {"query": "test"}}]
j.on_llm_end(
@@ -603,13 +629,12 @@ class TestOpenAIMessageFormat:
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["event_type"] == "ai_tool_call"
assert messages[0]["content"]["role"] == "assistant"
assert messages[0]["content"]["type"] == "ai"
assert messages[0]["content"]["content"] == "Let me search"
assert len(messages[0]["content"]["tool_calls"]) == 1
tc = messages[0]["content"]["tool_calls"][0]
assert tc["id"] == "call_1"
assert tc["type"] == "function"
assert tc["function"]["name"] == "search"
assert tc["name"] == "search"
@pytest.mark.anyio
async def test_ai_tool_call_only_from_lead_agent(self, journal_setup):
@@ -637,11 +662,11 @@ class TestToolResultMessage:
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["event_type"] == "tool_result"
assert messages[0]["content"] == {
"role": "tool",
"tool_call_id": "call_abc",
"content": "search results here",
}
# Content is checkpoint-aligned model_dump format
assert messages[0]["content"]["type"] == "tool"
assert messages[0]["content"]["tool_call_id"] == "call_abc"
assert messages[0]["content"]["content"] == "search results here"
assert messages[0]["content"]["name"] == "web_search"
@pytest.mark.anyio
async def test_tool_result_missing_tool_call_id(self, journal_setup):
@@ -652,15 +677,128 @@ class TestToolResultMessage:
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"]["role"] == "tool"
assert messages[0]["content"]["type"] == "tool"
@pytest.mark.anyio
async def test_tool_error_no_tool_result_message(self, journal_setup):
async def test_tool_end_extracts_from_tool_message_object(self, journal_setup):
"""When LangChain passes a ToolMessage object as output, extract fields from it."""
from langchain_core.messages import ToolMessage
j, store = journal_setup
run_id = uuid4()
tool_msg = ToolMessage(
content="search results",
tool_call_id="call_from_obj",
name="web_search",
status="success",
)
j.on_tool_end(tool_msg, run_id=run_id)
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"]["type"] == "tool"
assert messages[0]["content"]["tool_call_id"] == "call_from_obj"
assert messages[0]["content"]["content"] == "search results"
assert messages[0]["content"]["name"] == "web_search"
assert messages[0]["metadata"]["tool_name"] == "web_search"
assert messages[0]["metadata"]["status"] == "success"
events = await store.list_events("t1", "r1")
tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
assert tool_end["metadata"]["tool_call_id"] == "call_from_obj"
assert tool_end["metadata"]["tool_name"] == "web_search"
@pytest.mark.anyio
async def test_tool_message_object_overrides_kwargs(self, journal_setup):
"""ToolMessage object fields take priority over kwargs."""
from langchain_core.messages import ToolMessage
j, store = journal_setup
run_id = uuid4()
tool_msg = ToolMessage(
content="result",
tool_call_id="call_obj",
name="tool_a",
status="success",
)
# Pass different values in kwargs — ToolMessage should win
j.on_tool_end(tool_msg, run_id=run_id, name="tool_b", tool_call_id="call_kwarg")
await j.flush()
messages = await store.list_messages("t1")
assert messages[0]["content"]["tool_call_id"] == "call_obj"
assert messages[0]["content"]["name"] == "tool_a"
assert messages[0]["metadata"]["tool_name"] == "tool_a"
@pytest.mark.anyio
async def test_tool_message_error_status(self, journal_setup):
"""ToolMessage with status='error' propagates status to metadata."""
from langchain_core.messages import ToolMessage
j, store = journal_setup
run_id = uuid4()
tool_msg = ToolMessage(
content="something went wrong",
tool_call_id="call_err",
name="web_fetch",
status="error",
)
j.on_tool_end(tool_msg, run_id=run_id)
await j.flush()
events = await store.list_events("t1", "r1")
tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
assert tool_end["metadata"]["status"] == "error"
messages = await store.list_messages("t1")
assert messages[0]["content"]["status"] == "error"
assert messages[0]["metadata"]["status"] == "error"
@pytest.mark.anyio
async def test_tool_message_fallback_to_cache(self, journal_setup):
"""If ToolMessage has empty tool_call_id, fall back to cache from on_tool_start."""
from langchain_core.messages import ToolMessage
j, store = journal_setup
run_id = uuid4()
j.on_tool_start({"name": "bash"}, "ls", run_id=run_id, tool_call_id="call_cached")
tool_msg = ToolMessage(
content="file list",
tool_call_id="",
name="bash",
)
j.on_tool_end(tool_msg, run_id=run_id)
await j.flush()
messages = await store.list_messages("t1")
assert messages[0]["content"]["tool_call_id"] == "call_cached"
@pytest.mark.anyio
async def test_tool_error_produces_tool_result_message(self, journal_setup):
j, store = journal_setup
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch", tool_call_id="call_1")
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
assert len(messages) == 1
assert messages[0]["event_type"] == "tool_result"
assert messages[0]["content"]["type"] == "tool"
assert messages[0]["content"]["tool_call_id"] == "call_1"
assert "timeout" in messages[0]["content"]["content"]
assert messages[0]["content"]["status"] == "error"
assert messages[0]["metadata"]["status"] == "error"
@pytest.mark.anyio
async def test_tool_error_uses_cached_tool_call_id(self, journal_setup):
"""on_tool_error should fall back to cached tool_call_id from on_tool_start."""
j, store = journal_setup
run_id = uuid4()
j.on_tool_start({"name": "web_fetch"}, "url", run_id=run_id, tool_call_id="call_cached")
j.on_tool_error(TimeoutError("timeout"), run_id=run_id, name="web_fetch")
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 1
assert messages[0]["content"]["tool_call_id"] == "call_cached"
def _make_base_messages():
@@ -745,11 +883,12 @@ class TestLlmRequestResponse:
assert not any(e["event_type"] == "llm_start" for e in events)
class TestMiddlewareTrace:
class TestMiddlewareEvents:
@pytest.mark.anyio
async def test_record_middleware(self, journal_setup):
async def test_record_middleware_uses_middleware_category(self, journal_setup):
j, store = journal_setup
j.record_middleware(
"title",
name="TitleMiddleware",
hook="after_model",
action="generate_title",
@@ -757,27 +896,60 @@ class TestMiddlewareTrace:
)
await j.flush()
events = await store.list_events("t1", "r1")
mw_events = [e for e in events if e["event_type"] == "middleware"]
mw_events = [e for e in events if e["event_type"] == "middleware:title"]
assert len(mw_events) == 1
assert mw_events[0]["category"] == "trace"
assert mw_events[0]["category"] == "middleware"
assert mw_events[0]["content"]["name"] == "TitleMiddleware"
assert mw_events[0]["content"]["hook"] == "after_model"
assert mw_events[0]["content"]["action"] == "generate_title"
assert mw_events[0]["content"]["changes"]["title"] == "Test Title"
@pytest.mark.anyio
async def test_middleware_events_not_in_messages(self, journal_setup):
"""Middleware events should not appear in list_messages()."""
j, store = journal_setup
j.record_middleware(
"title",
name="TitleMiddleware",
hook="after_model",
action="generate_title",
changes={"title": "Test"},
)
await j.flush()
messages = await store.list_messages("t1")
assert len(messages) == 0
@pytest.mark.anyio
async def test_middleware_tag_variants(self, journal_setup):
"""Different middleware tags produce distinct event_types."""
j, store = journal_setup
j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={})
j.record_middleware("guardrail", name="GuardrailMiddleware", hook="before_tool", action="deny", changes={})
await j.flush()
events = await store.list_events("t1", "r1")
event_types = {e["event_type"] for e in events}
assert "middleware:title" in event_types
assert "middleware:guardrail" in event_types
class TestFullRunSequence:
@pytest.mark.anyio
async def test_complete_run_event_sequence(self):
"""Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply."""
"""Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply.
All message events use checkpoint-aligned model_dump format.
"""
from langchain_core.messages import HumanMessage
store = MemoryRunEventStore()
j = RunJournal("r1", "t1", store, flush_threshold=100)
# 1. Human message (written by worker, not journal)
# 1. Human message (written by worker, using model_dump format)
human_msg = HumanMessage(content="Search for quantum computing")
await store.put(
thread_id="t1", run_id="r1",
event_type="human_message", category="message",
content={"role": "user", "content": "Search for quantum computing"},
content=human_msg.model_dump(),
)
j.set_first_human_message("Search for quantum computing")
@@ -805,7 +977,7 @@ class TestFullRunSequence:
j.on_tool_end("Quantum computing results...", run_id=tool_id, name="web_search", tool_call_id="call_1")
# 5. Middleware: title generation
j.record_middleware("TitleMiddleware", "after_model", "generate_title", {"title": "Quantum Computing"})
j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={"title": "Quantum Computing"})
# 6. Second LLM call -> final reply
llm2_id = uuid4()
@@ -824,18 +996,19 @@ class TestFullRunSequence:
await asyncio.sleep(0.05)
await j.flush()
# Verify message sequence (what gets exported for training)
# Verify message sequence
messages = await store.list_messages("t1")
msg_types = [m["event_type"] for m in messages]
assert msg_types == ["human_message", "ai_tool_call", "tool_result", "ai_message"]
# Verify message content format
assert messages[0]["content"]["role"] == "user"
assert messages[1]["content"]["role"] == "assistant"
# Verify checkpoint-aligned format: all messages use "type" not "role"
assert messages[0]["content"]["type"] == "human"
assert messages[0]["content"]["content"] == "Search for quantum computing"
assert messages[1]["content"]["type"] == "ai"
assert "tool_calls" in messages[1]["content"]
assert messages[2]["content"]["role"] == "tool"
assert messages[2]["content"]["type"] == "tool"
assert messages[2]["content"]["tool_call_id"] == "call_1"
assert messages[3]["content"]["role"] == "assistant"
assert messages[3]["content"]["type"] == "ai"
assert messages[3]["content"]["content"] == "Here are the results about quantum computing..."
# Verify trace events
@@ -845,10 +1018,14 @@ class TestFullRunSequence:
assert "llm_response" in trace_types
assert "tool_start" in trace_types
assert "tool_end" in trace_types
assert "middleware" in trace_types
assert "llm_start" not in trace_types
assert "llm_end" not in trace_types
# Verify middleware events are in their own category
mw_events = [e for e in events if e["category"] == "middleware"]
assert len(mw_events) == 1
assert mw_events[0]["event_type"] == "middleware:title"
# Verify token accumulation
data = j.get_completion_data()
assert data["total_tokens"] == 420 # 120 + 300
@@ -857,7 +1034,7 @@ class TestFullRunSequence:
assert data["message_count"] == 1 # only final ai_message counts
assert data["last_ai_message"] == "Here are the results about quantum computing..."
# Verify training data export is trivial
training_messages = [m["content"] for m in messages]
assert all(isinstance(m, dict) for m in training_messages)
assert all("role" in m for m in training_messages)
# Verify all message contents are checkpoint-aligned dicts with "type" field
for m in messages:
assert isinstance(m["content"], dict)
assert "type" in m["content"]