mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 23:46:50 +00:00
feat: refine token usage display modes (#2329)
* feat: refine token usage display modes * docs: clarify token usage accounting semantics * fix: avoid duplicate subtask debug keys * style: format token usage tests * chore: address token attribution review feedback * Update test_token_usage_middleware.py * Update test_token_usage_middleware.py * chore: simplify token attribution fallback * fix token usage metadata follow-up handling --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -437,6 +437,85 @@ class TestStream:
|
||||
call_kwargs = agent.stream.call_args.kwargs
|
||||
assert "messages" in call_kwargs["stream_mode"]
|
||||
|
||||
def test_stream_emits_additional_kwargs_updates_for_streamed_ai_messages(self, client):
|
||||
"""stream() emits a follow-up AI event when attribution metadata arrives via values."""
|
||||
assembled = AIMessage(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
agent = MagicMock()
|
||||
agent.stream.return_value = iter(
|
||||
[
|
||||
("messages", (AIMessageChunk(content="Hello!", id="ai-1"), {})),
|
||||
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
events = list(client.stream("hi", thread_id="t-stream-kwargs"))
|
||||
|
||||
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
|
||||
assert any(event.data.get("content") == "Hello!" for event in ai_events)
|
||||
assert any(event.data.get("additional_kwargs", {}).get("token_usage_attribution", {}).get("kind") == "final_answer" for event in ai_events)
|
||||
|
||||
def test_stream_emits_new_additional_kwargs_after_prior_metadata(self, client):
|
||||
"""stream() emits later attribution metadata even after earlier kwargs for the same id."""
|
||||
attribution = {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
assembled = AIMessage(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={
|
||||
"reasoning_content": "Thinking first.",
|
||||
"token_usage_attribution": attribution,
|
||||
},
|
||||
)
|
||||
agent = MagicMock()
|
||||
agent.stream.return_value = iter(
|
||||
[
|
||||
(
|
||||
"messages",
|
||||
(
|
||||
AIMessageChunk(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={"reasoning_content": "Thinking first."},
|
||||
),
|
||||
{},
|
||||
),
|
||||
),
|
||||
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
events = list(client.stream("hi", thread_id="t-stream-kwargs-delta"))
|
||||
|
||||
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
|
||||
metadata_events = [event for event in ai_events if event.data.get("additional_kwargs")]
|
||||
|
||||
assert metadata_events[0].data["additional_kwargs"] == {"reasoning_content": "Thinking first."}
|
||||
assert metadata_events[1].data["content"] == ""
|
||||
assert metadata_events[1].data["additional_kwargs"] == {"token_usage_attribution": attribution}
|
||||
|
||||
def test_chat_accumulates_streamed_deltas(self, client):
|
||||
"""chat() concatenates per-id deltas from messages mode."""
|
||||
agent = MagicMock()
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Tests for DeerFlowClient message serialization helpers."""
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
|
||||
def test_serialize_ai_message_preserves_additional_kwargs():
|
||||
message = AIMessage(
|
||||
content="done",
|
||||
additional_kwargs={
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
},
|
||||
usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
|
||||
)
|
||||
|
||||
serialized = DeerFlowClient._serialize_message(message)
|
||||
|
||||
assert serialized["type"] == "ai"
|
||||
assert serialized["usage_metadata"] == {
|
||||
"input_tokens": 12,
|
||||
"output_tokens": 3,
|
||||
"total_tokens": 15,
|
||||
}
|
||||
assert serialized["additional_kwargs"] == {
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_serialize_human_message_preserves_additional_kwargs():
|
||||
message = HumanMessage(
|
||||
content="hello",
|
||||
additional_kwargs={"files": [{"name": "diagram.png"}]},
|
||||
)
|
||||
|
||||
serialized = DeerFlowClient._serialize_message(message)
|
||||
|
||||
assert serialized == {
|
||||
"type": "human",
|
||||
"content": "hello",
|
||||
"id": None,
|
||||
"additional_kwargs": {"files": [{"name": "diagram.png"}]},
|
||||
}
|
||||
@@ -1,32 +1,157 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
"""Tests for TokenUsageMiddleware attribution annotations."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
||||
from deerflow.agents.middlewares.token_usage_middleware import (
|
||||
TOKEN_USAGE_ATTRIBUTION_KEY,
|
||||
TokenUsageMiddleware,
|
||||
)
|
||||
|
||||
|
||||
def test_after_model_logs_usage_metadata_counts():
|
||||
middleware = TokenUsageMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="done",
|
||||
usage_metadata={
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
},
|
||||
)
|
||||
def _make_runtime():
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "test-thread"}
|
||||
return runtime
|
||||
|
||||
|
||||
class TestTokenUsageMiddleware:
|
||||
def test_annotates_todo_updates_with_structured_actions(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "write_todos:1",
|
||||
"name": "write_todos",
|
||||
"args": {
|
||||
"todos": [
|
||||
{"content": "Inspect streaming path", "status": "completed"},
|
||||
{"content": "Design token attribution schema", "status": "in_progress"},
|
||||
]
|
||||
},
|
||||
}
|
||||
],
|
||||
usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120},
|
||||
)
|
||||
|
||||
state = {
|
||||
"messages": [message],
|
||||
"todos": [
|
||||
{"content": "Inspect streaming path", "status": "in_progress"},
|
||||
{"content": "Design token attribution schema", "status": "pending"},
|
||||
],
|
||||
}
|
||||
|
||||
result = middleware.after_model(state, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
updated_message = result["messages"][0]
|
||||
attribution = updated_message.additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "tool_batch"
|
||||
assert attribution["shared_attribution"] is True
|
||||
assert attribution["tool_call_ids"] == ["write_todos:1"]
|
||||
assert attribution["actions"] == [
|
||||
{
|
||||
"kind": "todo_complete",
|
||||
"content": "Inspect streaming path",
|
||||
"tool_call_id": "write_todos:1",
|
||||
},
|
||||
{
|
||||
"kind": "todo_start",
|
||||
"content": "Design token attribution schema",
|
||||
"tool_call_id": "write_todos:1",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("deerflow.agents.middlewares.token_usage_middleware.logger.info") as info_mock:
|
||||
result = middleware.after_model(state=state, runtime=MagicMock())
|
||||
def test_annotates_subagent_and_search_steps(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "task:1",
|
||||
"name": "task",
|
||||
"args": {
|
||||
"description": "spec-coder patch message grouping",
|
||||
"subagent_type": "general-purpose",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "web_search:1",
|
||||
"name": "web_search",
|
||||
"args": {"query": "LangGraph useStream messages tuple"},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert result is None
|
||||
info_mock.assert_called_once_with(
|
||||
"LLM token usage: input=%s output=%s total=%s",
|
||||
10,
|
||||
5,
|
||||
15,
|
||||
)
|
||||
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "tool_batch"
|
||||
assert attribution["shared_attribution"] is True
|
||||
assert attribution["actions"] == [
|
||||
{
|
||||
"kind": "subagent",
|
||||
"description": "spec-coder patch message grouping",
|
||||
"subagent_type": "general-purpose",
|
||||
"tool_call_id": "task:1",
|
||||
},
|
||||
{
|
||||
"kind": "search",
|
||||
"tool_name": "web_search",
|
||||
"query": "LangGraph useStream messages tuple",
|
||||
"tool_call_id": "web_search:1",
|
||||
},
|
||||
]
|
||||
|
||||
def test_marks_final_answer_when_no_tools(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(content="Here is the final answer.")
|
||||
|
||||
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "final_answer"
|
||||
assert attribution["shared_attribution"] is False
|
||||
assert attribution["actions"] == []
|
||||
|
||||
def test_annotates_removed_todos(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "write_todos:remove",
|
||||
"name": "write_todos",
|
||||
"args": {
|
||||
"todos": [],
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
result = middleware.after_model(
|
||||
{
|
||||
"messages": [message],
|
||||
"todos": [
|
||||
{"content": "Archive obsolete plan", "status": "pending"},
|
||||
],
|
||||
},
|
||||
_make_runtime(),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "todo_update"
|
||||
assert attribution["shared_attribution"] is False
|
||||
assert attribution["actions"] == [
|
||||
{
|
||||
"kind": "todo_remove",
|
||||
"content": "Archive obsolete plan",
|
||||
"tool_call_id": "write_todos:remove",
|
||||
}
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user