mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 07:26:50 +00:00
feat(events): align message events with checkpoint format and add middleware tag injection
- Message events (ai_message, ai_tool_call, tool_result, human_message) now use
BaseMessage.model_dump() format, matching LangGraph checkpoint values.messages
- on_tool_end extracts tool_call_id/name/status from ToolMessage objects
- on_tool_error now emits tool_result message events with error status
- record_middleware uses middleware:{tag} event_type and middleware category
- Summarization custom events use middleware:summarize category
- TitleMiddleware injects middleware:title tag via get_config() inheritance
- SummarizationMiddleware model bound with middleware:summarize tag
- Worker writes human_message using HumanMessage.model_dump()
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -146,8 +146,11 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
lambda: SummarizationConfig(enabled=True, model_name="model-masswork"),
|
||||
)
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
captured: dict[str, object] = {}
|
||||
fake_model = object()
|
||||
fake_model = MagicMock()
|
||||
fake_model.with_config.return_value = fake_model
|
||||
|
||||
def _fake_create_chat_model(*, name=None, thinking_enabled, reasoning_effort=None):
|
||||
captured["name"] = name
|
||||
@@ -163,3 +166,4 @@ def test_create_summarization_middleware_uses_configured_model_alias(monkeypatch
|
||||
assert captured["name"] == "model-masswork"
|
||||
assert captured["thinking_enabled"] is False
|
||||
assert middleware["model"] is fake_model
|
||||
fake_model.with_config.assert_called_once_with(tags=["middleware:summarize"])
|
||||
|
||||
@@ -20,24 +20,32 @@ def journal_setup():
|
||||
return j, store
|
||||
|
||||
|
||||
def _make_llm_response(content="Hello", usage=None, tool_calls=None):
|
||||
"""Create a mock LLM response with a message."""
|
||||
def _make_llm_response(content="Hello", usage=None, tool_calls=None, additional_kwargs=None):
|
||||
"""Create a mock LLM response with a message.
|
||||
|
||||
model_dump() returns checkpoint-aligned format matching real AIMessage.
|
||||
"""
|
||||
msg = MagicMock()
|
||||
msg.type = "ai"
|
||||
msg.content = content
|
||||
msg.id = f"msg-{id(msg)}"
|
||||
msg.tool_calls = tool_calls or []
|
||||
msg.invalid_tool_calls = []
|
||||
msg.response_metadata = {"model_name": "test-model"}
|
||||
msg.usage_metadata = usage
|
||||
# Provide a real model_dump so serialize_lc_object returns a plain dict
|
||||
# (needed for DB-backed tests where json.dumps must succeed).
|
||||
msg.additional_kwargs = additional_kwargs or {}
|
||||
msg.name = None
|
||||
# model_dump returns checkpoint-aligned format
|
||||
msg.model_dump.return_value = {
|
||||
"type": "ai",
|
||||
"content": content,
|
||||
"additional_kwargs": additional_kwargs or {},
|
||||
"response_metadata": {"model_name": "test-model"},
|
||||
"type": "ai",
|
||||
"name": None,
|
||||
"id": msg.id,
|
||||
"tool_calls": tool_calls or [],
|
||||
"invalid_tool_calls": [],
|
||||
"usage_metadata": usage,
|
||||
"response_metadata": {"model_name": "test-model"},
|
||||
}
|
||||
|
||||
gen = MagicMock()
|
||||
@@ -71,7 +79,9 @@ class TestLlmCallbacks:
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "ai_message"
|
||||
assert messages[0]["content"] == {"role": "assistant", "content": "Answer"}
|
||||
# Content is checkpoint-aligned model_dump format
|
||||
assert messages[0]["content"]["type"] == "ai"
|
||||
assert messages[0]["content"]["content"] == "Answer"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_on_llm_end_with_tool_calls_produces_ai_tool_call(self, journal_setup):
|
||||
@@ -211,10 +221,14 @@ class TestCustomEvents:
|
||||
events = await store.list_events("t1", "r1")
|
||||
trace = [e for e in events if e["event_type"] == "summarization"]
|
||||
assert len(trace) == 1
|
||||
# Summarization goes to middleware category, not message
|
||||
mw_events = [e for e in events if e["event_type"] == "middleware:summarize"]
|
||||
assert len(mw_events) == 1
|
||||
assert mw_events[0]["category"] == "middleware"
|
||||
assert mw_events[0]["content"] == {"role": "system", "content": "Context was summarized."}
|
||||
# No message events from summarization
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "summary"
|
||||
assert messages[0]["content"] == {"role": "system", "content": "Context was summarized."}
|
||||
assert len(messages) == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_non_summarization_custom_event(self, journal_setup):
|
||||
@@ -375,8 +389,11 @@ class TestDbBackedLifecycle:
|
||||
record = await mgr.create("t1", "lead_agent")
|
||||
run_id = record.run_id
|
||||
|
||||
# Write human_message
|
||||
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content={"role": "user", "content": "Hello DB"})
|
||||
# Write human_message (checkpoint-aligned format)
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
human_msg = HumanMessage(content="Hello DB")
|
||||
await event_store.put(thread_id="t1", run_id=run_id, event_type="human_message", category="message", content=human_msg.model_dump())
|
||||
|
||||
# Simulate journal
|
||||
journal = RunJournal(run_id, "t1", event_store, flush_threshold=100)
|
||||
@@ -406,12 +423,14 @@ class TestDbBackedLifecycle:
|
||||
assert row["status"] == "success"
|
||||
assert row["total_tokens"] == 15
|
||||
|
||||
# Verify messages from DB
|
||||
# Verify messages from DB (checkpoint-aligned format)
|
||||
messages = await event_store.list_messages("t1")
|
||||
assert len(messages) == 2
|
||||
assert messages[0]["event_type"] == "human_message"
|
||||
assert messages[0]["content"]["type"] == "human"
|
||||
assert messages[1]["event_type"] == "ai_message"
|
||||
assert messages[1]["content"] == {"role": "assistant", "content": "DB response"}
|
||||
assert messages[1]["content"]["type"] == "ai"
|
||||
assert messages[1]["content"]["content"] == "DB response"
|
||||
|
||||
# Verify events from DB
|
||||
events = await event_store.list_events("t1", run_id)
|
||||
@@ -560,38 +579,45 @@ class TestDictContent:
|
||||
await close_engine()
|
||||
|
||||
|
||||
class TestOpenAIHumanMessage:
|
||||
class TestCheckpointAlignedHumanMessage:
|
||||
@pytest.mark.anyio
|
||||
async def test_human_message_openai_format(self):
|
||||
async def test_human_message_checkpoint_format(self):
|
||||
"""human_message content uses model_dump() checkpoint format."""
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
store = MemoryRunEventStore()
|
||||
human_msg = HumanMessage(content="What is AI?")
|
||||
await store.put(
|
||||
thread_id="t1",
|
||||
run_id="r1",
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content={"role": "user", "content": "What is AI?"},
|
||||
content=human_msg.model_dump(),
|
||||
metadata={"message_id": "msg_001"},
|
||||
)
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == {"role": "user", "content": "What is AI?"}
|
||||
assert messages[0]["content"]["role"] == "user"
|
||||
assert messages[0]["content"]["type"] == "human"
|
||||
assert messages[0]["content"]["content"] == "What is AI?"
|
||||
|
||||
|
||||
class TestOpenAIMessageFormat:
|
||||
class TestCheckpointAlignedMessageFormat:
|
||||
@pytest.mark.anyio
|
||||
async def test_ai_message_openai_format(self, journal_setup):
|
||||
"""ai_message content should be OpenAI assistant message dict."""
|
||||
async def test_ai_message_checkpoint_format(self, journal_setup):
|
||||
"""ai_message content should be checkpoint-aligned model_dump dict."""
|
||||
j, store = journal_setup
|
||||
j.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), tags=["lead_agent"])
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"] == {"role": "assistant", "content": "Answer"}
|
||||
assert messages[0]["content"]["type"] == "ai"
|
||||
assert messages[0]["content"]["content"] == "Answer"
|
||||
assert "response_metadata" in messages[0]["content"]
|
||||
assert "additional_kwargs" in messages[0]["content"]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ai_tool_call_event(self, journal_setup):
|
||||
"""LLM response with tool_calls should produce ai_tool_call message event."""
|
||||
"""LLM response with tool_calls should produce ai_tool_call with model_dump content."""
|
||||
j, store = journal_setup
|
||||
tool_calls = [{"id": "call_1", "name": "search", "args": {"query": "test"}}]
|
||||
j.on_llm_end(
|
||||
@@ -603,13 +629,12 @@ class TestOpenAIMessageFormat:
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "ai_tool_call"
|
||||
assert messages[0]["content"]["role"] == "assistant"
|
||||
assert messages[0]["content"]["type"] == "ai"
|
||||
assert messages[0]["content"]["content"] == "Let me search"
|
||||
assert len(messages[0]["content"]["tool_calls"]) == 1
|
||||
tc = messages[0]["content"]["tool_calls"][0]
|
||||
assert tc["id"] == "call_1"
|
||||
assert tc["type"] == "function"
|
||||
assert tc["function"]["name"] == "search"
|
||||
assert tc["name"] == "search"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_ai_tool_call_only_from_lead_agent(self, journal_setup):
|
||||
@@ -637,11 +662,11 @@ class TestToolResultMessage:
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "tool_result"
|
||||
assert messages[0]["content"] == {
|
||||
"role": "tool",
|
||||
"tool_call_id": "call_abc",
|
||||
"content": "search results here",
|
||||
}
|
||||
# Content is checkpoint-aligned model_dump format
|
||||
assert messages[0]["content"]["type"] == "tool"
|
||||
assert messages[0]["content"]["tool_call_id"] == "call_abc"
|
||||
assert messages[0]["content"]["content"] == "search results here"
|
||||
assert messages[0]["content"]["name"] == "web_search"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_result_missing_tool_call_id(self, journal_setup):
|
||||
@@ -652,15 +677,128 @@ class TestToolResultMessage:
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"]["role"] == "tool"
|
||||
assert messages[0]["content"]["type"] == "tool"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_error_no_tool_result_message(self, journal_setup):
|
||||
async def test_tool_end_extracts_from_tool_message_object(self, journal_setup):
|
||||
"""When LangChain passes a ToolMessage object as output, extract fields from it."""
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
tool_msg = ToolMessage(
|
||||
content="search results",
|
||||
tool_call_id="call_from_obj",
|
||||
name="web_search",
|
||||
status="success",
|
||||
)
|
||||
j.on_tool_end(tool_msg, run_id=run_id)
|
||||
await j.flush()
|
||||
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"]["type"] == "tool"
|
||||
assert messages[0]["content"]["tool_call_id"] == "call_from_obj"
|
||||
assert messages[0]["content"]["content"] == "search results"
|
||||
assert messages[0]["content"]["name"] == "web_search"
|
||||
assert messages[0]["metadata"]["tool_name"] == "web_search"
|
||||
assert messages[0]["metadata"]["status"] == "success"
|
||||
|
||||
events = await store.list_events("t1", "r1")
|
||||
tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
|
||||
assert tool_end["metadata"]["tool_call_id"] == "call_from_obj"
|
||||
assert tool_end["metadata"]["tool_name"] == "web_search"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_message_object_overrides_kwargs(self, journal_setup):
|
||||
"""ToolMessage object fields take priority over kwargs."""
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
tool_msg = ToolMessage(
|
||||
content="result",
|
||||
tool_call_id="call_obj",
|
||||
name="tool_a",
|
||||
status="success",
|
||||
)
|
||||
# Pass different values in kwargs — ToolMessage should win
|
||||
j.on_tool_end(tool_msg, run_id=run_id, name="tool_b", tool_call_id="call_kwarg")
|
||||
await j.flush()
|
||||
|
||||
messages = await store.list_messages("t1")
|
||||
assert messages[0]["content"]["tool_call_id"] == "call_obj"
|
||||
assert messages[0]["content"]["name"] == "tool_a"
|
||||
assert messages[0]["metadata"]["tool_name"] == "tool_a"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_message_error_status(self, journal_setup):
|
||||
"""ToolMessage with status='error' propagates status to metadata."""
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
tool_msg = ToolMessage(
|
||||
content="something went wrong",
|
||||
tool_call_id="call_err",
|
||||
name="web_fetch",
|
||||
status="error",
|
||||
)
|
||||
j.on_tool_end(tool_msg, run_id=run_id)
|
||||
await j.flush()
|
||||
|
||||
events = await store.list_events("t1", "r1")
|
||||
tool_end = [e for e in events if e["event_type"] == "tool_end"][0]
|
||||
assert tool_end["metadata"]["status"] == "error"
|
||||
|
||||
messages = await store.list_messages("t1")
|
||||
assert messages[0]["content"]["status"] == "error"
|
||||
assert messages[0]["metadata"]["status"] == "error"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_message_fallback_to_cache(self, journal_setup):
|
||||
"""If ToolMessage has empty tool_call_id, fall back to cache from on_tool_start."""
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_tool_start({"name": "bash"}, "ls", run_id=run_id, tool_call_id="call_cached")
|
||||
tool_msg = ToolMessage(
|
||||
content="file list",
|
||||
tool_call_id="",
|
||||
name="bash",
|
||||
)
|
||||
j.on_tool_end(tool_msg, run_id=run_id)
|
||||
await j.flush()
|
||||
|
||||
messages = await store.list_messages("t1")
|
||||
assert messages[0]["content"]["tool_call_id"] == "call_cached"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_error_produces_tool_result_message(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
j.on_tool_error(TimeoutError("timeout"), run_id=uuid4(), name="web_fetch", tool_call_id="call_1")
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 0
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["event_type"] == "tool_result"
|
||||
assert messages[0]["content"]["type"] == "tool"
|
||||
assert messages[0]["content"]["tool_call_id"] == "call_1"
|
||||
assert "timeout" in messages[0]["content"]["content"]
|
||||
assert messages[0]["content"]["status"] == "error"
|
||||
assert messages[0]["metadata"]["status"] == "error"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_tool_error_uses_cached_tool_call_id(self, journal_setup):
|
||||
"""on_tool_error should fall back to cached tool_call_id from on_tool_start."""
|
||||
j, store = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_tool_start({"name": "web_fetch"}, "url", run_id=run_id, tool_call_id="call_cached")
|
||||
j.on_tool_error(TimeoutError("timeout"), run_id=run_id, name="web_fetch")
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 1
|
||||
assert messages[0]["content"]["tool_call_id"] == "call_cached"
|
||||
|
||||
|
||||
def _make_base_messages():
|
||||
@@ -745,11 +883,12 @@ class TestLlmRequestResponse:
|
||||
assert not any(e["event_type"] == "llm_start" for e in events)
|
||||
|
||||
|
||||
class TestMiddlewareTrace:
|
||||
class TestMiddlewareEvents:
|
||||
@pytest.mark.anyio
|
||||
async def test_record_middleware(self, journal_setup):
|
||||
async def test_record_middleware_uses_middleware_category(self, journal_setup):
|
||||
j, store = journal_setup
|
||||
j.record_middleware(
|
||||
"title",
|
||||
name="TitleMiddleware",
|
||||
hook="after_model",
|
||||
action="generate_title",
|
||||
@@ -757,27 +896,60 @@ class TestMiddlewareTrace:
|
||||
)
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
mw_events = [e for e in events if e["event_type"] == "middleware"]
|
||||
mw_events = [e for e in events if e["event_type"] == "middleware:title"]
|
||||
assert len(mw_events) == 1
|
||||
assert mw_events[0]["category"] == "trace"
|
||||
assert mw_events[0]["category"] == "middleware"
|
||||
assert mw_events[0]["content"]["name"] == "TitleMiddleware"
|
||||
assert mw_events[0]["content"]["hook"] == "after_model"
|
||||
assert mw_events[0]["content"]["action"] == "generate_title"
|
||||
assert mw_events[0]["content"]["changes"]["title"] == "Test Title"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_middleware_events_not_in_messages(self, journal_setup):
|
||||
"""Middleware events should not appear in list_messages()."""
|
||||
j, store = journal_setup
|
||||
j.record_middleware(
|
||||
"title",
|
||||
name="TitleMiddleware",
|
||||
hook="after_model",
|
||||
action="generate_title",
|
||||
changes={"title": "Test"},
|
||||
)
|
||||
await j.flush()
|
||||
messages = await store.list_messages("t1")
|
||||
assert len(messages) == 0
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_middleware_tag_variants(self, journal_setup):
|
||||
"""Different middleware tags produce distinct event_types."""
|
||||
j, store = journal_setup
|
||||
j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={})
|
||||
j.record_middleware("guardrail", name="GuardrailMiddleware", hook="before_tool", action="deny", changes={})
|
||||
await j.flush()
|
||||
events = await store.list_events("t1", "r1")
|
||||
event_types = {e["event_type"] for e in events}
|
||||
assert "middleware:title" in event_types
|
||||
assert "middleware:guardrail" in event_types
|
||||
|
||||
|
||||
class TestFullRunSequence:
|
||||
@pytest.mark.anyio
|
||||
async def test_complete_run_event_sequence(self):
|
||||
"""Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply."""
|
||||
"""Simulate a full run: user -> LLM -> tool_call -> tool_result -> LLM -> final reply.
|
||||
|
||||
All message events use checkpoint-aligned model_dump format.
|
||||
"""
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal("r1", "t1", store, flush_threshold=100)
|
||||
|
||||
# 1. Human message (written by worker, not journal)
|
||||
# 1. Human message (written by worker, using model_dump format)
|
||||
human_msg = HumanMessage(content="Search for quantum computing")
|
||||
await store.put(
|
||||
thread_id="t1", run_id="r1",
|
||||
event_type="human_message", category="message",
|
||||
content={"role": "user", "content": "Search for quantum computing"},
|
||||
content=human_msg.model_dump(),
|
||||
)
|
||||
j.set_first_human_message("Search for quantum computing")
|
||||
|
||||
@@ -805,7 +977,7 @@ class TestFullRunSequence:
|
||||
j.on_tool_end("Quantum computing results...", run_id=tool_id, name="web_search", tool_call_id="call_1")
|
||||
|
||||
# 5. Middleware: title generation
|
||||
j.record_middleware("TitleMiddleware", "after_model", "generate_title", {"title": "Quantum Computing"})
|
||||
j.record_middleware("title", name="TitleMiddleware", hook="after_model", action="generate_title", changes={"title": "Quantum Computing"})
|
||||
|
||||
# 6. Second LLM call -> final reply
|
||||
llm2_id = uuid4()
|
||||
@@ -824,18 +996,19 @@ class TestFullRunSequence:
|
||||
await asyncio.sleep(0.05)
|
||||
await j.flush()
|
||||
|
||||
# Verify message sequence (what gets exported for training)
|
||||
# Verify message sequence
|
||||
messages = await store.list_messages("t1")
|
||||
msg_types = [m["event_type"] for m in messages]
|
||||
assert msg_types == ["human_message", "ai_tool_call", "tool_result", "ai_message"]
|
||||
|
||||
# Verify message content format
|
||||
assert messages[0]["content"]["role"] == "user"
|
||||
assert messages[1]["content"]["role"] == "assistant"
|
||||
# Verify checkpoint-aligned format: all messages use "type" not "role"
|
||||
assert messages[0]["content"]["type"] == "human"
|
||||
assert messages[0]["content"]["content"] == "Search for quantum computing"
|
||||
assert messages[1]["content"]["type"] == "ai"
|
||||
assert "tool_calls" in messages[1]["content"]
|
||||
assert messages[2]["content"]["role"] == "tool"
|
||||
assert messages[2]["content"]["type"] == "tool"
|
||||
assert messages[2]["content"]["tool_call_id"] == "call_1"
|
||||
assert messages[3]["content"]["role"] == "assistant"
|
||||
assert messages[3]["content"]["type"] == "ai"
|
||||
assert messages[3]["content"]["content"] == "Here are the results about quantum computing..."
|
||||
|
||||
# Verify trace events
|
||||
@@ -845,10 +1018,14 @@ class TestFullRunSequence:
|
||||
assert "llm_response" in trace_types
|
||||
assert "tool_start" in trace_types
|
||||
assert "tool_end" in trace_types
|
||||
assert "middleware" in trace_types
|
||||
assert "llm_start" not in trace_types
|
||||
assert "llm_end" not in trace_types
|
||||
|
||||
# Verify middleware events are in their own category
|
||||
mw_events = [e for e in events if e["category"] == "middleware"]
|
||||
assert len(mw_events) == 1
|
||||
assert mw_events[0]["event_type"] == "middleware:title"
|
||||
|
||||
# Verify token accumulation
|
||||
data = j.get_completion_data()
|
||||
assert data["total_tokens"] == 420 # 120 + 300
|
||||
@@ -857,7 +1034,7 @@ class TestFullRunSequence:
|
||||
assert data["message_count"] == 1 # only final ai_message counts
|
||||
assert data["last_ai_message"] == "Here are the results about quantum computing..."
|
||||
|
||||
# Verify training data export is trivial
|
||||
training_messages = [m["content"] for m in messages]
|
||||
assert all(isinstance(m, dict) for m in training_messages)
|
||||
assert all("role" in m for m in training_messages)
|
||||
# Verify all message contents are checkpoint-aligned dicts with "type" field
|
||||
for m in messages:
|
||||
assert isinstance(m["content"], dict)
|
||||
assert "type" in m["content"]
|
||||
|
||||
Reference in New Issue
Block a user