feat: refine token usage display modes (#2329)

* feat: refine token usage display modes

* docs: clarify token usage accounting semantics

* fix: avoid duplicate subtask debug keys

* style: format token usage tests

* chore: address token attribution review feedback

* Update test_token_usage_middleware.py

* Update test_token_usage_middleware.py

* chore: simplify token attribution fallback

* fix token usage metadata follow-up handling

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
YuJitang
2026-05-04 09:56:16 +08:00
committed by GitHub
parent 82e7936d36
commit d02f762ab0
20 changed files with 2346 additions and 222 deletions
@@ -1,31 +1,270 @@
"""Middleware for logging LLM token usage."""
"""Middleware for logging token usage and annotating step attribution."""
from __future__ import annotations
import logging
from typing import override
from collections import defaultdict
from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo
from langchain_core.messages import AIMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
TOKEN_USAGE_ATTRIBUTION_KEY = "token_usage_attribution"
def _string_arg(value: Any) -> str | None:
if isinstance(value, str):
normalized = value.strip()
return normalized or None
return None
def _normalize_todos(value: Any) -> list[Todo]:
if not isinstance(value, list):
return []
normalized: list[Todo] = []
for item in value:
if not isinstance(item, dict):
continue
todo: Todo = {}
content = _string_arg(item.get("content"))
status = item.get("status")
if content is not None:
todo["content"] = content
if status in {"pending", "in_progress", "completed"}:
todo["status"] = status
normalized.append(todo)
return normalized
def _todo_action_kind(previous: Todo | None, current: Todo) -> str:
status = current.get("status")
previous_content = previous.get("content") if previous else None
current_content = current.get("content")
if previous is None:
if status == "completed":
return "todo_complete"
if status == "in_progress":
return "todo_start"
return "todo_update"
if previous_content != current_content:
return "todo_update"
if status == "completed":
return "todo_complete"
if status == "in_progress":
return "todo_start"
return "todo_update"
def _build_todo_actions(previous_todos: list[Todo], next_todos: list[Todo]) -> list[dict[str, Any]]:
# This is the single source of truth for precise write_todos token
# attribution. The frontend intentionally falls back to a generic
# "Update to-do list" label when this metadata is missing or malformed.
previous_by_content: dict[str, list[tuple[int, Todo]]] = defaultdict(list)
matched_previous_indices: set[int] = set()
for index, todo in enumerate(previous_todos):
content = todo.get("content")
if isinstance(content, str) and content:
previous_by_content[content].append((index, todo))
actions: list[dict[str, Any]] = []
for index, todo in enumerate(next_todos):
content = todo.get("content")
if not isinstance(content, str) or not content:
continue
previous_match: Todo | None = None
content_matches = previous_by_content.get(content)
if content_matches:
while content_matches and content_matches[0][0] in matched_previous_indices:
content_matches.pop(0)
if content_matches:
previous_index, previous_match = content_matches.pop(0)
matched_previous_indices.add(previous_index)
if previous_match is None and index < len(previous_todos) and index not in matched_previous_indices:
previous_match = previous_todos[index]
matched_previous_indices.add(index)
if previous_match is not None:
previous_content = previous_match.get("content")
previous_status = previous_match.get("status")
if previous_content == content and previous_status == todo.get("status"):
continue
actions.append(
{
"kind": _todo_action_kind(previous_match, todo),
"content": content,
}
)
for index, todo in enumerate(previous_todos):
if index in matched_previous_indices:
continue
content = todo.get("content")
if not isinstance(content, str) or not content:
continue
actions.append(
{
"kind": "todo_remove",
"content": content,
}
)
return actions
def _describe_tool_call(tool_call: dict[str, Any], todos: list[Todo]) -> list[dict[str, Any]]:
name = _string_arg(tool_call.get("name")) or "unknown"
args = tool_call.get("args") if isinstance(tool_call.get("args"), dict) else {}
tool_call_id = _string_arg(tool_call.get("id"))
if name == "write_todos":
next_todos = _normalize_todos(args.get("todos"))
actions = _build_todo_actions(todos, next_todos)
if not actions:
return [
{
"kind": "tool",
"tool_name": name,
"tool_call_id": tool_call_id,
}
]
return [
{
**action,
"tool_call_id": tool_call_id,
}
for action in actions
]
if name == "task":
return [
{
"kind": "subagent",
"description": _string_arg(args.get("description")),
"subagent_type": _string_arg(args.get("subagent_type")),
"tool_call_id": tool_call_id,
}
]
if name in {"web_search", "image_search"}:
query = _string_arg(args.get("query"))
return [
{
"kind": "search",
"tool_name": name,
"query": query,
"tool_call_id": tool_call_id,
}
]
if name == "present_files":
return [
{
"kind": "present_files",
"tool_call_id": tool_call_id,
}
]
if name == "ask_clarification":
return [
{
"kind": "clarification",
"tool_call_id": tool_call_id,
}
]
return [
{
"kind": "tool",
"tool_name": name,
"description": _string_arg(args.get("description")),
"tool_call_id": tool_call_id,
}
]
def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
if actions:
first_kind = actions[0].get("kind")
if len(actions) == 1 and first_kind in {"todo_start", "todo_complete", "todo_update", "todo_remove"}:
return "todo_update"
if len(actions) == 1 and first_kind == "subagent":
return "subagent_dispatch"
return "tool_batch"
if message.content:
return "final_answer"
return "thinking"
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
tool_calls = getattr(message, "tool_calls", None) or []
actions: list[dict[str, Any]] = []
current_todos = list(todos)
for raw_tool_call in tool_calls:
if not isinstance(raw_tool_call, dict):
continue
described_actions = _describe_tool_call(raw_tool_call, current_todos)
actions.extend(described_actions)
if raw_tool_call.get("name") == "write_todos":
args = raw_tool_call.get("args") if isinstance(raw_tool_call.get("args"), dict) else {}
current_todos = _normalize_todos(args.get("todos"))
tool_call_ids: list[str] = []
for tool_call in tool_calls:
if not isinstance(tool_call, dict):
continue
tool_call_id = _string_arg(tool_call.get("id"))
if tool_call_id is not None:
tool_call_ids.append(tool_call_id)
return {
# Schema changes should remain additive where possible so older
# frontends can ignore unknown fields and fall back safely.
"version": 1,
"kind": _infer_step_kind(message, actions),
"shared_attribution": len(actions) > 1,
"tool_call_ids": tool_call_ids,
"actions": actions,
}
class TokenUsageMiddleware(AgentMiddleware):
"""Logs token usage from model response usage_metadata."""
"""Logs token usage from model responses and annotates the AI step."""
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._log_usage(state)
def _log_usage(self, state: AgentState) -> None:
def _apply(self, state: AgentState) -> dict | None:
messages = state.get("messages", [])
if not messages:
return None
last = messages[-1]
if not isinstance(last, AIMessage):
return None
usage = getattr(last, "usage_metadata", None)
if usage:
logger.info(
@@ -34,4 +273,22 @@ class TokenUsageMiddleware(AgentMiddleware):
usage.get("output_tokens", "?"),
usage.get("total_tokens", "?"),
)
return None
todos = state.get("todos") or []
attribution = _build_attribution(last, todos if isinstance(todos, list) else [])
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
return None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
return {"messages": [updated_msg]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state)
@override
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
return self._apply(state)
+98 -19
View File
@@ -264,25 +264,35 @@ class DeerFlowClient:
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
@staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
def _serialize_additional_kwargs(msg) -> dict[str, Any] | None:
"""Copy message additional_kwargs when present."""
additional_kwargs = getattr(msg, "additional_kwargs", None)
if isinstance(additional_kwargs, dict) and additional_kwargs:
return dict(additional_kwargs)
return None
@staticmethod
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI text event."""
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
if usage:
data["usage_metadata"] = usage
if additional_kwargs:
data["additional_kwargs"] = additional_kwargs
return StreamEvent(type="messages-tuple", data=data)
@staticmethod
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
def _ai_tool_calls_event(msg_id: str | None, tool_calls, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
"""Build a ``messages-tuple`` AI tool-calls event."""
return StreamEvent(
type="messages-tuple",
data={
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
},
)
data: dict[str, Any] = {
"type": "ai",
"content": "",
"id": msg_id,
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
}
if additional_kwargs:
data["additional_kwargs"] = additional_kwargs
return StreamEvent(type="messages-tuple", data=data)
@staticmethod
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
@@ -307,19 +317,30 @@ class DeerFlowClient:
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
if getattr(msg, "usage_metadata", None):
d["usage_metadata"] = msg.usage_metadata
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, ToolMessage):
return {
d = {
"type": "tool",
"content": DeerFlowClient._extract_text(msg.content),
"name": getattr(msg, "name", None),
"tool_call_id": getattr(msg, "tool_call_id", None),
"id": getattr(msg, "id", None),
}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, HumanMessage):
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
d = {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
if isinstance(msg, SystemMessage):
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
d = {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
d["additional_kwargs"] = additional_kwargs
return d
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
@staticmethod
@@ -542,6 +563,7 @@ class DeerFlowClient:
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "additional_kwargs": {...}}
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
"""
@@ -564,6 +586,7 @@ class DeerFlowClient:
# in both the final ``messages`` chunk and the values snapshot —
# count it only on whichever arrives first.
counted_usage_ids: set[str] = set()
sent_additional_kwargs_by_id: dict[str, dict[str, Any]] = {}
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
@@ -593,6 +616,20 @@ class DeerFlowClient:
"total_tokens": total_tokens,
}
def _unsent_additional_kwargs(msg_id: str | None, additional_kwargs: dict[str, Any] | None) -> dict[str, Any] | None:
if not additional_kwargs:
return None
if not msg_id:
return additional_kwargs
sent = sent_additional_kwargs_by_id.setdefault(msg_id, {})
delta = {key: value for key, value in additional_kwargs.items() if sent.get(key) != value}
if not delta:
return None
sent.update(delta)
return delta
for item in self._agent.stream(
state,
config=config,
@@ -620,17 +657,31 @@ class DeerFlowClient:
if isinstance(msg_chunk, AIMessage):
text = self._extract_text(msg_chunk.content)
additional_kwargs = self._serialize_additional_kwargs(msg_chunk)
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
sent_additional_kwargs = False
if text:
if msg_id:
streamed_ids.add(msg_id)
yield self._ai_text_event(msg_id, text, counted_usage)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_text_event(
msg_id,
text,
counted_usage,
additional_kwargs_delta,
)
sent_additional_kwargs = bool(additional_kwargs_delta)
if msg_chunk.tool_calls:
if msg_id:
streamed_ids.add(msg_id)
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_tool_calls_event(
msg_id,
msg_chunk.tool_calls,
additional_kwargs_delta,
)
elif isinstance(msg_chunk, ToolMessage):
if msg_id:
@@ -653,17 +704,45 @@ class DeerFlowClient:
if msg_id and msg_id in streamed_ids:
if isinstance(msg, AIMessage):
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
additional_kwargs = self._serialize_additional_kwargs(msg)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
if additional_kwargs_delta:
# Metadata-only follow-up: ``messages-tuple`` has no
# dedicated attribution event, so clients should
# merge this empty-content AI event by message id
# and ignore it for text rendering.
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
continue
if isinstance(msg, AIMessage):
counted_usage = _account_usage(msg_id, msg.usage_metadata)
additional_kwargs = self._serialize_additional_kwargs(msg)
sent_additional_kwargs = False
if msg.tool_calls:
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_tool_calls_event(
msg_id,
msg.tool_calls,
additional_kwargs_delta,
)
sent_additional_kwargs = bool(additional_kwargs_delta)
text = self._extract_text(msg.content)
if text:
yield self._ai_text_event(msg_id, text, counted_usage)
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
yield self._ai_text_event(
msg_id,
text,
counted_usage,
additional_kwargs_delta,
)
elif msg_id:
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
if not additional_kwargs_delta:
continue
# See the metadata-only follow-up convention above.
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
elif isinstance(msg, ToolMessage):
yield self._tool_message_event(msg)