feat: refine token usage display modes (#2329)

* feat: refine token usage display modes

* docs: clarify token usage accounting semantics

* fix: avoid duplicate subtask debug keys

* style: format token usage tests

* chore: address token attribution review feedback

* Update test_token_usage_middleware.py

* Update test_token_usage_middleware.py

* chore: simplify token attribution fallback

* fix token usage metadata follow-up handling

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
YuJitang
2026-05-04 09:56:16 +08:00
committed by GitHub
parent 82e7936d36
commit d02f762ab0
20 changed files with 2346 additions and 222 deletions
+79
View File
@@ -437,6 +437,85 @@ class TestStream:
call_kwargs = agent.stream.call_args.kwargs
assert "messages" in call_kwargs["stream_mode"]
def test_stream_emits_additional_kwargs_updates_for_streamed_ai_messages(self, client):
"""stream() emits a follow-up AI event when attribution metadata arrives via values."""
assembled = AIMessage(
content="Hello!",
id="ai-1",
additional_kwargs={
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
},
)
agent = MagicMock()
agent.stream.return_value = iter(
[
("messages", (AIMessageChunk(content="Hello!", id="ai-1"), {})),
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
]
)
with (
patch.object(client, "_ensure_agent"),
patch.object(client, "_agent", agent),
):
events = list(client.stream("hi", thread_id="t-stream-kwargs"))
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
assert any(event.data.get("content") == "Hello!" for event in ai_events)
assert any(event.data.get("additional_kwargs", {}).get("token_usage_attribution", {}).get("kind") == "final_answer" for event in ai_events)
def test_stream_emits_new_additional_kwargs_after_prior_metadata(self, client):
"""stream() emits later attribution metadata even after earlier kwargs for the same id."""
attribution = {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
assembled = AIMessage(
content="Hello!",
id="ai-1",
additional_kwargs={
"reasoning_content": "Thinking first.",
"token_usage_attribution": attribution,
},
)
agent = MagicMock()
agent.stream.return_value = iter(
[
(
"messages",
(
AIMessageChunk(
content="Hello!",
id="ai-1",
additional_kwargs={"reasoning_content": "Thinking first."},
),
{},
),
),
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
]
)
with (
patch.object(client, "_ensure_agent"),
patch.object(client, "_agent", agent),
):
events = list(client.stream("hi", thread_id="t-stream-kwargs-delta"))
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
metadata_events = [event for event in ai_events if event.data.get("additional_kwargs")]
assert metadata_events[0].data["additional_kwargs"] == {"reasoning_content": "Thinking first."}
assert metadata_events[1].data["content"] == ""
assert metadata_events[1].data["additional_kwargs"] == {"token_usage_attribution": attribution}
def test_chat_accumulates_streamed_deltas(self, client):
"""chat() concatenates per-id deltas from messages mode."""
agent = MagicMock()
@@ -0,0 +1,53 @@
"""Tests for DeerFlowClient message serialization helpers."""
from langchain_core.messages import AIMessage, HumanMessage
from deerflow.client import DeerFlowClient
def test_serialize_ai_message_preserves_additional_kwargs():
message = AIMessage(
content="done",
additional_kwargs={
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
},
usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
)
serialized = DeerFlowClient._serialize_message(message)
assert serialized["type"] == "ai"
assert serialized["usage_metadata"] == {
"input_tokens": 12,
"output_tokens": 3,
"total_tokens": 15,
}
assert serialized["additional_kwargs"] == {
"token_usage_attribution": {
"version": 1,
"kind": "final_answer",
"shared_attribution": False,
"actions": [],
}
}
def test_serialize_human_message_preserves_additional_kwargs():
message = HumanMessage(
content="hello",
additional_kwargs={"files": [{"name": "diagram.png"}]},
)
serialized = DeerFlowClient._serialize_message(message)
assert serialized == {
"type": "human",
"content": "hello",
"id": None,
"additional_kwargs": {"files": [{"name": "diagram.png"}]},
}
+149 -24
View File
@@ -1,32 +1,157 @@
from unittest.mock import MagicMock, patch
"""Tests for TokenUsageMiddleware attribution annotations."""
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY,
TokenUsageMiddleware,
)
def test_after_model_logs_usage_metadata_counts():
middleware = TokenUsageMiddleware()
state = {
"messages": [
AIMessage(
content="done",
usage_metadata={
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
},
)
def _make_runtime():
runtime = MagicMock()
runtime.context = {"thread_id": "test-thread"}
return runtime
class TestTokenUsageMiddleware:
def test_annotates_todo_updates_with_structured_actions(self):
middleware = TokenUsageMiddleware()
message = AIMessage(
content="",
tool_calls=[
{
"id": "write_todos:1",
"name": "write_todos",
"args": {
"todos": [
{"content": "Inspect streaming path", "status": "completed"},
{"content": "Design token attribution schema", "status": "in_progress"},
]
},
}
],
usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120},
)
state = {
"messages": [message],
"todos": [
{"content": "Inspect streaming path", "status": "in_progress"},
{"content": "Design token attribution schema", "status": "pending"},
],
}
result = middleware.after_model(state, _make_runtime())
assert result is not None
updated_message = result["messages"][0]
attribution = updated_message.additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
assert attribution["kind"] == "tool_batch"
assert attribution["shared_attribution"] is True
assert attribution["tool_call_ids"] == ["write_todos:1"]
assert attribution["actions"] == [
{
"kind": "todo_complete",
"content": "Inspect streaming path",
"tool_call_id": "write_todos:1",
},
{
"kind": "todo_start",
"content": "Design token attribution schema",
"tool_call_id": "write_todos:1",
},
]
}
with patch("deerflow.agents.middlewares.token_usage_middleware.logger.info") as info_mock:
result = middleware.after_model(state=state, runtime=MagicMock())
def test_annotates_subagent_and_search_steps(self):
middleware = TokenUsageMiddleware()
message = AIMessage(
content="",
tool_calls=[
{
"id": "task:1",
"name": "task",
"args": {
"description": "spec-coder patch message grouping",
"subagent_type": "general-purpose",
},
},
{
"id": "web_search:1",
"name": "web_search",
"args": {"query": "LangGraph useStream messages tuple"},
},
],
)
assert result is None
info_mock.assert_called_once_with(
"LLM token usage: input=%s output=%s total=%s",
10,
5,
15,
)
result = middleware.after_model({"messages": [message]}, _make_runtime())
assert result is not None
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
assert attribution["kind"] == "tool_batch"
assert attribution["shared_attribution"] is True
assert attribution["actions"] == [
{
"kind": "subagent",
"description": "spec-coder patch message grouping",
"subagent_type": "general-purpose",
"tool_call_id": "task:1",
},
{
"kind": "search",
"tool_name": "web_search",
"query": "LangGraph useStream messages tuple",
"tool_call_id": "web_search:1",
},
]
def test_marks_final_answer_when_no_tools(self):
middleware = TokenUsageMiddleware()
message = AIMessage(content="Here is the final answer.")
result = middleware.after_model({"messages": [message]}, _make_runtime())
assert result is not None
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
assert attribution["kind"] == "final_answer"
assert attribution["shared_attribution"] is False
assert attribution["actions"] == []
def test_annotates_removed_todos(self):
middleware = TokenUsageMiddleware()
message = AIMessage(
content="",
tool_calls=[
{
"id": "write_todos:remove",
"name": "write_todos",
"args": {
"todos": [],
},
}
],
)
result = middleware.after_model(
{
"messages": [message],
"todos": [
{"content": "Archive obsolete plan", "status": "pending"},
],
},
_make_runtime(),
)
assert result is not None
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
assert attribution["kind"] == "todo_update"
assert attribution["shared_attribution"] is False
assert attribution["actions"] == [
{
"kind": "todo_remove",
"content": "Archive obsolete plan",
"tool_call_id": "write_todos:remove",
}
]