mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
fix(runtime): persist run message summaries (#2850)
* fix(runtime): persist run message summaries (#2849)
* fix(runtime): dedupe run message summaries
This commit is contained in:
@@ -20,12 +20,13 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Mapping
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from langchain_core.callbacks import BaseCallbackHandler
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
from langchain_core.messages import AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
from langchain_core.messages import AIMessage, AnyMessage, BaseMessage, HumanMessage, ToolMessage
|
||||||
from langgraph.types import Command
|
from langgraph.types import Command
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -71,6 +72,7 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id
|
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id
|
||||||
self._counted_llm_run_ids: set[str] = set()
|
self._counted_llm_run_ids: set[str] = set()
|
||||||
self._counted_external_source_ids: set[str] = set()
|
self._counted_external_source_ids: set[str] = set()
|
||||||
|
self._counted_message_llm_run_ids: set[str] = set()
|
||||||
|
|
||||||
# Convenience fields
|
# Convenience fields
|
||||||
self._last_ai_msg: str | None = None
|
self._last_ai_msg: str | None = None
|
||||||
@@ -86,6 +88,50 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
|
|
||||||
# -- Lifecycle callbacks --
|
# -- Lifecycle callbacks --
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _message_text(message: BaseMessage) -> str:
|
||||||
|
"""Extract displayable text from a message's mixed content shape."""
|
||||||
|
content = getattr(message, "content", None)
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, str):
|
||||||
|
parts.append(block)
|
||||||
|
elif isinstance(block, Mapping):
|
||||||
|
text = block.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
else:
|
||||||
|
nested = block.get("content")
|
||||||
|
if isinstance(nested, str):
|
||||||
|
parts.append(nested)
|
||||||
|
return "".join(parts)
|
||||||
|
if isinstance(content, Mapping):
|
||||||
|
for key in ("text", "content"):
|
||||||
|
value = content.get(key)
|
||||||
|
if isinstance(value, str):
|
||||||
|
return value
|
||||||
|
|
||||||
|
text = getattr(message, "text", None)
|
||||||
|
if isinstance(text, str):
|
||||||
|
return text
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _record_message_summary(self, message: BaseMessage, *, caller: str | None = None) -> None:
    """Fold one message into the run-level summary fields.

    Every call increments ``self._msg_count``. ``self._last_ai_msg`` is
    replaced only when the message is an AI message, ``caller`` is ``None``
    or ``"lead_agent"``, and the message has non-blank text; the stored
    value is the stripped text truncated to 2000 characters.
    """
    self._msg_count += 1

    # ``last_ai_message`` should represent the lead agent's user-facing
    # answer, so anything that is not an AI message is only counted.
    if not (isinstance(message, AIMessage) or getattr(message, "type", None) == "ai"):
        return

    # Middleware/subagent model calls must not overwrite the lead answer.
    if caller is not None and caller != "lead_agent":
        return

    # Empty tool-call-only AI messages keep the previous useful text.
    summary = self._message_text(message).strip()
    if summary:
        self._last_ai_msg = summary[:2000]
|
||||||
|
|
||||||
def on_chain_start(
|
def on_chain_start(
|
||||||
self,
|
self,
|
||||||
serialized: dict[str, Any],
|
serialized: dict[str, Any],
|
||||||
@@ -164,6 +210,7 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
content=m.model_dump(),
|
content=m.model_dump(),
|
||||||
metadata={"caller": caller},
|
metadata={"caller": caller},
|
||||||
)
|
)
|
||||||
|
self._record_message_summary(m, caller=caller)
|
||||||
break
|
break
|
||||||
if self._first_human_msg:
|
if self._first_human_msg:
|
||||||
break
|
break
|
||||||
@@ -222,6 +269,8 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
"llm_call_index": call_index,
|
"llm_call_index": call_index,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
if rid not in self._counted_message_llm_run_ids:
|
||||||
|
self._record_message_summary(message, caller=caller)
|
||||||
|
|
||||||
# Token accumulation (dedup by langchain run_id to avoid double-counting
|
# Token accumulation (dedup by langchain run_id to avoid double-counting
|
||||||
# when the callback fires more than once for the same response)
|
# when the callback fires more than once for the same response)
|
||||||
@@ -245,6 +294,9 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
else:
|
else:
|
||||||
self._lead_agent_tokens += total_tk
|
self._lead_agent_tokens += total_tk
|
||||||
|
|
||||||
|
if messages:
|
||||||
|
self._counted_message_llm_run_ids.add(str(run_id))
|
||||||
|
|
||||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
self._llm_start_times.pop(str(run_id), None)
|
self._llm_start_times.pop(str(run_id), None)
|
||||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||||
@@ -260,12 +312,14 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
if isinstance(output, ToolMessage):
|
if isinstance(output, ToolMessage):
|
||||||
msg = cast(ToolMessage, output)
|
msg = cast(ToolMessage, output)
|
||||||
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
self._put(event_type="llm.tool.result", category="message", content=msg.model_dump())
|
||||||
|
self._record_message_summary(msg)
|
||||||
elif isinstance(output, Command):
|
elif isinstance(output, Command):
|
||||||
cmd = cast(Command, output)
|
cmd = cast(Command, output)
|
||||||
messages = cmd.update.get("messages", [])
|
messages = cmd.update.get("messages", [])
|
||||||
for message in messages:
|
for message in messages:
|
||||||
if isinstance(message, BaseMessage):
|
if isinstance(message, BaseMessage):
|
||||||
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
self._put(event_type="llm.tool.result", category="message", content=message.model_dump())
|
||||||
|
self._record_message_summary(message)
|
||||||
else:
|
else:
|
||||||
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
logger.warning(f"on_tool_end {run_id}: command update message is not BaseMessage: {type(message)}")
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -339,6 +339,99 @@ class TestConvenienceFields:
|
|||||||
data = j.get_completion_data()
|
data = j.get_completion_data()
|
||||||
assert data["first_human_message"] == "What is AI?"
|
assert data["first_human_message"] == "What is AI?"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_completion_data_counts_human_ai_and_tool_messages(self, journal_setup):
    """One human, one AI and one tool message produce a count of three."""
    from langchain_core.messages import HumanMessage, ToolMessage

    journal, _ = journal_setup
    question = HumanMessage(content="Question")
    tool_result = ToolMessage(content="Tool result", tool_call_id="call_1", name="search")

    journal.on_chat_model_start({}, [[question]], run_id=uuid4(), tags=["lead_agent"])
    journal.on_llm_end(_make_llm_response("Answer"), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
    journal.on_tool_end(tool_result, run_id=uuid4())

    completion = journal.get_completion_data()

    assert completion["message_count"] == 3
    assert completion["first_human_message"] == "Question"
    assert completion["last_ai_message"] == "Answer"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_tool_call_only_ai_does_not_clear_last_ai_message(self, journal_setup):
    """An AI response with empty text and only tool calls is counted but keeps the prior answer."""
    journal, _ = journal_setup
    lead_kwargs = {"parent_run_id": None, "tags": ["lead_agent"]}
    tool_call_only = _make_llm_response("", tool_calls=[{"id": "call_1", "name": "search", "args": {}}])

    journal.on_llm_end(_make_llm_response("Useful answer"), run_id=uuid4(), **lead_kwargs)
    journal.on_llm_end(tool_call_only, run_id=uuid4(), **lead_kwargs)

    completion = journal.get_completion_data()

    assert completion["message_count"] == 2
    assert completion["last_ai_message"] == "Useful answer"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_last_ai_message_extracts_mixed_content_without_extra_newlines(self, journal_setup):
    """List content is joined with no separator; blocks without text are dropped."""
    journal, _ = journal_setup
    mixed_content = [
        {"type": "text", "text": "First "},
        {"type": "text", "content": "second"},
        " third",
        {"type": "image", "url": "ignored"},
    ]

    journal.on_llm_end(
        _make_llm_response(mixed_content),
        run_id=uuid4(),
        parent_run_id=None,
        tags=["lead_agent"],
    )

    completion = journal.get_completion_data()

    assert completion["message_count"] == 1
    assert completion["last_ai_message"] == "First second third"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_last_ai_message_extracts_mapping_content(self, journal_setup):
    """Mapping-shaped content yields the string stored under its ``content`` key."""
    journal, _ = journal_setup
    response = _make_llm_response({"content": "Nested answer"})

    journal.on_llm_end(response, run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])

    completion = journal.get_completion_data()

    assert completion["message_count"] == 1
    assert completion["last_ai_message"] == "Nested answer"
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_duplicate_llm_run_id_does_not_double_count_message_summary(self, journal_setup):
    """Two ``on_llm_end`` calls sharing a run_id count one message; usage from the second call is kept."""
    journal, _ = journal_setup
    shared_run_id = uuid4()
    usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}

    # First delivery carries no usage, the repeat for the same run does.
    without_usage = _make_llm_response("Answer", usage=None)
    journal.on_llm_end(without_usage, run_id=shared_run_id, parent_run_id=None, tags=["lead_agent"])
    with_usage = _make_llm_response("Answer", usage=usage)
    journal.on_llm_end(with_usage, run_id=shared_run_id, parent_run_id=None, tags=["lead_agent"])

    completion = journal.get_completion_data()

    assert completion["message_count"] == 1
    assert completion["last_ai_message"] == "Answer"
    assert completion["total_tokens"] == 15
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_subagent_ai_does_not_overwrite_lead_last_ai_message(self, journal_setup):
    """A response tagged ``subagent:research`` is counted but does not replace the lead answer."""
    journal, _ = journal_setup
    lead_response = _make_llm_response("Lead answer")
    journal.on_llm_end(lead_response, run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
    subagent_response = _make_llm_response("Subagent detail")
    journal.on_llm_end(subagent_response, run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])

    completion = journal.get_completion_data()

    assert completion["message_count"] == 2
    assert completion["last_ai_message"] == "Lead answer"
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_completion_data(self, journal_setup):
|
async def test_get_completion_data(self, journal_setup):
|
||||||
j, _ = journal_setup
|
j, _ = journal_setup
|
||||||
|
|||||||
Reference in New Issue
Block a user