mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-21 23:46:50 +00:00
feat: refine token usage display modes (#2329)
* feat: refine token usage display modes * docs: clarify token usage accounting semantics * fix: avoid duplicate subtask debug keys * style: format token usage tests * chore: address token attribution review feedback * Update test_token_usage_middleware.py * Update test_token_usage_middleware.py * chore: simplify token attribution fallback * fix token usage metadata follow-up handling --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -1,31 +1,270 @@
|
||||
"""Middleware for logging LLM token usage."""
|
||||
"""Middleware for logging token usage and annotating step attribution."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import override
|
||||
from collections import defaultdict
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.todo import Todo
|
||||
from langchain_core.messages import AIMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOKEN_USAGE_ATTRIBUTION_KEY = "token_usage_attribution"
|
||||
|
||||
|
||||
def _string_arg(value: Any) -> str | None:
|
||||
if isinstance(value, str):
|
||||
normalized = value.strip()
|
||||
return normalized or None
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_todos(value: Any) -> list[Todo]:
|
||||
if not isinstance(value, list):
|
||||
return []
|
||||
|
||||
normalized: list[Todo] = []
|
||||
for item in value:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
|
||||
todo: Todo = {}
|
||||
content = _string_arg(item.get("content"))
|
||||
status = item.get("status")
|
||||
|
||||
if content is not None:
|
||||
todo["content"] = content
|
||||
if status in {"pending", "in_progress", "completed"}:
|
||||
todo["status"] = status
|
||||
|
||||
normalized.append(todo)
|
||||
|
||||
return normalized
|
||||
|
||||
|
||||
def _todo_action_kind(previous: Todo | None, current: Todo) -> str:
|
||||
status = current.get("status")
|
||||
previous_content = previous.get("content") if previous else None
|
||||
current_content = current.get("content")
|
||||
|
||||
if previous is None:
|
||||
if status == "completed":
|
||||
return "todo_complete"
|
||||
if status == "in_progress":
|
||||
return "todo_start"
|
||||
return "todo_update"
|
||||
|
||||
if previous_content != current_content:
|
||||
return "todo_update"
|
||||
|
||||
if status == "completed":
|
||||
return "todo_complete"
|
||||
if status == "in_progress":
|
||||
return "todo_start"
|
||||
return "todo_update"
|
||||
|
||||
|
||||
def _build_todo_actions(previous_todos: list[Todo], next_todos: list[Todo]) -> list[dict[str, Any]]:
|
||||
# This is the single source of truth for precise write_todos token
|
||||
# attribution. The frontend intentionally falls back to a generic
|
||||
# "Update to-do list" label when this metadata is missing or malformed.
|
||||
previous_by_content: dict[str, list[tuple[int, Todo]]] = defaultdict(list)
|
||||
matched_previous_indices: set[int] = set()
|
||||
|
||||
for index, todo in enumerate(previous_todos):
|
||||
content = todo.get("content")
|
||||
if isinstance(content, str) and content:
|
||||
previous_by_content[content].append((index, todo))
|
||||
|
||||
actions: list[dict[str, Any]] = []
|
||||
|
||||
for index, todo in enumerate(next_todos):
|
||||
content = todo.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
continue
|
||||
|
||||
previous_match: Todo | None = None
|
||||
content_matches = previous_by_content.get(content)
|
||||
if content_matches:
|
||||
while content_matches and content_matches[0][0] in matched_previous_indices:
|
||||
content_matches.pop(0)
|
||||
if content_matches:
|
||||
previous_index, previous_match = content_matches.pop(0)
|
||||
matched_previous_indices.add(previous_index)
|
||||
|
||||
if previous_match is None and index < len(previous_todos) and index not in matched_previous_indices:
|
||||
previous_match = previous_todos[index]
|
||||
matched_previous_indices.add(index)
|
||||
|
||||
if previous_match is not None:
|
||||
previous_content = previous_match.get("content")
|
||||
previous_status = previous_match.get("status")
|
||||
if previous_content == content and previous_status == todo.get("status"):
|
||||
continue
|
||||
|
||||
actions.append(
|
||||
{
|
||||
"kind": _todo_action_kind(previous_match, todo),
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
for index, todo in enumerate(previous_todos):
|
||||
if index in matched_previous_indices:
|
||||
continue
|
||||
|
||||
content = todo.get("content")
|
||||
if not isinstance(content, str) or not content:
|
||||
continue
|
||||
|
||||
actions.append(
|
||||
{
|
||||
"kind": "todo_remove",
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
|
||||
return actions
|
||||
|
||||
|
||||
def _describe_tool_call(tool_call: dict[str, Any], todos: list[Todo]) -> list[dict[str, Any]]:
|
||||
name = _string_arg(tool_call.get("name")) or "unknown"
|
||||
args = tool_call.get("args") if isinstance(tool_call.get("args"), dict) else {}
|
||||
tool_call_id = _string_arg(tool_call.get("id"))
|
||||
|
||||
if name == "write_todos":
|
||||
next_todos = _normalize_todos(args.get("todos"))
|
||||
actions = _build_todo_actions(todos, next_todos)
|
||||
if not actions:
|
||||
return [
|
||||
{
|
||||
"kind": "tool",
|
||||
"tool_name": name,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
return [
|
||||
{
|
||||
**action,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
for action in actions
|
||||
]
|
||||
|
||||
if name == "task":
|
||||
return [
|
||||
{
|
||||
"kind": "subagent",
|
||||
"description": _string_arg(args.get("description")),
|
||||
"subagent_type": _string_arg(args.get("subagent_type")),
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
if name in {"web_search", "image_search"}:
|
||||
query = _string_arg(args.get("query"))
|
||||
return [
|
||||
{
|
||||
"kind": "search",
|
||||
"tool_name": name,
|
||||
"query": query,
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
if name == "present_files":
|
||||
return [
|
||||
{
|
||||
"kind": "present_files",
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
if name == "ask_clarification":
|
||||
return [
|
||||
{
|
||||
"kind": "clarification",
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
return [
|
||||
{
|
||||
"kind": "tool",
|
||||
"tool_name": name,
|
||||
"description": _string_arg(args.get("description")),
|
||||
"tool_call_id": tool_call_id,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
|
||||
if actions:
|
||||
first_kind = actions[0].get("kind")
|
||||
if len(actions) == 1 and first_kind in {"todo_start", "todo_complete", "todo_update", "todo_remove"}:
|
||||
return "todo_update"
|
||||
if len(actions) == 1 and first_kind == "subagent":
|
||||
return "subagent_dispatch"
|
||||
return "tool_batch"
|
||||
|
||||
if message.content:
|
||||
return "final_answer"
|
||||
return "thinking"
|
||||
|
||||
|
||||
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
actions: list[dict[str, Any]] = []
|
||||
current_todos = list(todos)
|
||||
|
||||
for raw_tool_call in tool_calls:
|
||||
if not isinstance(raw_tool_call, dict):
|
||||
continue
|
||||
|
||||
described_actions = _describe_tool_call(raw_tool_call, current_todos)
|
||||
actions.extend(described_actions)
|
||||
|
||||
if raw_tool_call.get("name") == "write_todos":
|
||||
args = raw_tool_call.get("args") if isinstance(raw_tool_call.get("args"), dict) else {}
|
||||
current_todos = _normalize_todos(args.get("todos"))
|
||||
|
||||
tool_call_ids: list[str] = []
|
||||
for tool_call in tool_calls:
|
||||
if not isinstance(tool_call, dict):
|
||||
continue
|
||||
|
||||
tool_call_id = _string_arg(tool_call.get("id"))
|
||||
if tool_call_id is not None:
|
||||
tool_call_ids.append(tool_call_id)
|
||||
|
||||
return {
|
||||
# Schema changes should remain additive where possible so older
|
||||
# frontends can ignore unknown fields and fall back safely.
|
||||
"version": 1,
|
||||
"kind": _infer_step_kind(message, actions),
|
||||
"shared_attribution": len(actions) > 1,
|
||||
"tool_call_ids": tool_call_ids,
|
||||
"actions": actions,
|
||||
}
|
||||
|
||||
|
||||
class TokenUsageMiddleware(AgentMiddleware):
|
||||
"""Logs token usage from model response usage_metadata."""
|
||||
"""Logs token usage from model responses and annotates the AI step."""
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._log_usage(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._log_usage(state)
|
||||
|
||||
def _log_usage(self, state: AgentState) -> None:
|
||||
def _apply(self, state: AgentState) -> dict | None:
|
||||
messages = state.get("messages", [])
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
return None
|
||||
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
if usage:
|
||||
logger.info(
|
||||
@@ -34,4 +273,22 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
usage.get("output_tokens", "?"),
|
||||
usage.get("total_tokens", "?"),
|
||||
)
|
||||
return None
|
||||
|
||||
todos = state.get("todos") or []
|
||||
attribution = _build_attribution(last, todos if isinstance(todos, list) else [])
|
||||
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
|
||||
|
||||
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
|
||||
return None
|
||||
|
||||
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
|
||||
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
return {"messages": [updated_msg]}
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state)
|
||||
|
||||
@override
|
||||
async def aafter_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
return self._apply(state)
|
||||
|
||||
@@ -264,25 +264,35 @@ class DeerFlowClient:
|
||||
return [{"name": tc["name"], "args": tc["args"], "id": tc.get("id")} for tc in tool_calls]
|
||||
|
||||
@staticmethod
|
||||
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI text event, attaching usage when present."""
|
||||
def _serialize_additional_kwargs(msg) -> dict[str, Any] | None:
|
||||
"""Copy message additional_kwargs when present."""
|
||||
additional_kwargs = getattr(msg, "additional_kwargs", None)
|
||||
if isinstance(additional_kwargs, dict) and additional_kwargs:
|
||||
return dict(additional_kwargs)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _ai_text_event(msg_id: str | None, text: str, usage: dict | None, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI text event."""
|
||||
data: dict[str, Any] = {"type": "ai", "content": text, "id": msg_id}
|
||||
if usage:
|
||||
data["usage_metadata"] = usage
|
||||
if additional_kwargs:
|
||||
data["additional_kwargs"] = additional_kwargs
|
||||
return StreamEvent(type="messages-tuple", data=data)
|
||||
|
||||
@staticmethod
|
||||
def _ai_tool_calls_event(msg_id: str | None, tool_calls) -> "StreamEvent":
|
||||
def _ai_tool_calls_event(msg_id: str | None, tool_calls, additional_kwargs: dict[str, Any] | None = None) -> "StreamEvent":
|
||||
"""Build a ``messages-tuple`` AI tool-calls event."""
|
||||
return StreamEvent(
|
||||
type="messages-tuple",
|
||||
data={
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"id": msg_id,
|
||||
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
|
||||
},
|
||||
)
|
||||
data: dict[str, Any] = {
|
||||
"type": "ai",
|
||||
"content": "",
|
||||
"id": msg_id,
|
||||
"tool_calls": DeerFlowClient._serialize_tool_calls(tool_calls),
|
||||
}
|
||||
if additional_kwargs:
|
||||
data["additional_kwargs"] = additional_kwargs
|
||||
return StreamEvent(type="messages-tuple", data=data)
|
||||
|
||||
@staticmethod
|
||||
def _tool_message_event(msg: ToolMessage) -> "StreamEvent":
|
||||
@@ -307,19 +317,30 @@ class DeerFlowClient:
|
||||
d["tool_calls"] = DeerFlowClient._serialize_tool_calls(msg.tool_calls)
|
||||
if getattr(msg, "usage_metadata", None):
|
||||
d["usage_metadata"] = msg.usage_metadata
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
if isinstance(msg, ToolMessage):
|
||||
return {
|
||||
d = {
|
||||
"type": "tool",
|
||||
"content": DeerFlowClient._extract_text(msg.content),
|
||||
"name": getattr(msg, "name", None),
|
||||
"tool_call_id": getattr(msg, "tool_call_id", None),
|
||||
"id": getattr(msg, "id", None),
|
||||
}
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
if isinstance(msg, HumanMessage):
|
||||
return {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
d = {"type": "human", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
if isinstance(msg, SystemMessage):
|
||||
return {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
d = {"type": "system", "content": msg.content, "id": getattr(msg, "id", None)}
|
||||
if additional_kwargs := DeerFlowClient._serialize_additional_kwargs(msg):
|
||||
d["additional_kwargs"] = additional_kwargs
|
||||
return d
|
||||
return {"type": "unknown", "content": str(msg), "id": getattr(msg, "id", None)}
|
||||
|
||||
@staticmethod
|
||||
@@ -542,6 +563,7 @@ class DeerFlowClient:
|
||||
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str}
|
||||
- type="messages-tuple" data={"type": "ai", "content": <delta>, "id": str, "usage_metadata": {...}}
|
||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "tool_calls": [...]}
|
||||
- type="messages-tuple" data={"type": "ai", "content": "", "id": str, "additional_kwargs": {...}}
|
||||
- type="messages-tuple" data={"type": "tool", "content": str, "name": str, "tool_call_id": str, "id": str}
|
||||
- type="end" data={"usage": {"input_tokens": int, "output_tokens": int, "total_tokens": int}}
|
||||
"""
|
||||
@@ -564,6 +586,7 @@ class DeerFlowClient:
|
||||
# in both the final ``messages`` chunk and the values snapshot —
|
||||
# count it only on whichever arrives first.
|
||||
counted_usage_ids: set[str] = set()
|
||||
sent_additional_kwargs_by_id: dict[str, dict[str, Any]] = {}
|
||||
cumulative_usage: dict[str, int] = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
|
||||
|
||||
def _account_usage(msg_id: str | None, usage: Any) -> dict | None:
|
||||
@@ -593,6 +616,20 @@ class DeerFlowClient:
|
||||
"total_tokens": total_tokens,
|
||||
}
|
||||
|
||||
def _unsent_additional_kwargs(msg_id: str | None, additional_kwargs: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
if not additional_kwargs:
|
||||
return None
|
||||
if not msg_id:
|
||||
return additional_kwargs
|
||||
|
||||
sent = sent_additional_kwargs_by_id.setdefault(msg_id, {})
|
||||
delta = {key: value for key, value in additional_kwargs.items() if sent.get(key) != value}
|
||||
if not delta:
|
||||
return None
|
||||
|
||||
sent.update(delta)
|
||||
return delta
|
||||
|
||||
for item in self._agent.stream(
|
||||
state,
|
||||
config=config,
|
||||
@@ -620,17 +657,31 @@ class DeerFlowClient:
|
||||
|
||||
if isinstance(msg_chunk, AIMessage):
|
||||
text = self._extract_text(msg_chunk.content)
|
||||
additional_kwargs = self._serialize_additional_kwargs(msg_chunk)
|
||||
counted_usage = _account_usage(msg_id, msg_chunk.usage_metadata)
|
||||
sent_additional_kwargs = False
|
||||
|
||||
if text:
|
||||
if msg_id:
|
||||
streamed_ids.add(msg_id)
|
||||
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_text_event(
|
||||
msg_id,
|
||||
text,
|
||||
counted_usage,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
sent_additional_kwargs = bool(additional_kwargs_delta)
|
||||
|
||||
if msg_chunk.tool_calls:
|
||||
if msg_id:
|
||||
streamed_ids.add(msg_id)
|
||||
yield self._ai_tool_calls_event(msg_id, msg_chunk.tool_calls)
|
||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_tool_calls_event(
|
||||
msg_id,
|
||||
msg_chunk.tool_calls,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
|
||||
elif isinstance(msg_chunk, ToolMessage):
|
||||
if msg_id:
|
||||
@@ -653,17 +704,45 @@ class DeerFlowClient:
|
||||
if msg_id and msg_id in streamed_ids:
|
||||
if isinstance(msg, AIMessage):
|
||||
_account_usage(msg_id, getattr(msg, "usage_metadata", None))
|
||||
additional_kwargs = self._serialize_additional_kwargs(msg)
|
||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
if additional_kwargs_delta:
|
||||
# Metadata-only follow-up: ``messages-tuple`` has no
|
||||
# dedicated attribution event, so clients should
|
||||
# merge this empty-content AI event by message id
|
||||
# and ignore it for text rendering.
|
||||
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
|
||||
continue
|
||||
|
||||
if isinstance(msg, AIMessage):
|
||||
counted_usage = _account_usage(msg_id, msg.usage_metadata)
|
||||
additional_kwargs = self._serialize_additional_kwargs(msg)
|
||||
sent_additional_kwargs = False
|
||||
|
||||
if msg.tool_calls:
|
||||
yield self._ai_tool_calls_event(msg_id, msg.tool_calls)
|
||||
additional_kwargs_delta = _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_tool_calls_event(
|
||||
msg_id,
|
||||
msg.tool_calls,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
sent_additional_kwargs = bool(additional_kwargs_delta)
|
||||
|
||||
text = self._extract_text(msg.content)
|
||||
if text:
|
||||
yield self._ai_text_event(msg_id, text, counted_usage)
|
||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
yield self._ai_text_event(
|
||||
msg_id,
|
||||
text,
|
||||
counted_usage,
|
||||
additional_kwargs_delta,
|
||||
)
|
||||
elif msg_id:
|
||||
additional_kwargs_delta = None if sent_additional_kwargs else _unsent_additional_kwargs(msg_id, additional_kwargs)
|
||||
if not additional_kwargs_delta:
|
||||
continue
|
||||
# See the metadata-only follow-up convention above.
|
||||
yield self._ai_text_event(msg_id, "", None, additional_kwargs_delta)
|
||||
|
||||
elif isinstance(msg, ToolMessage):
|
||||
yield self._tool_message_event(msg)
|
||||
|
||||
@@ -437,6 +437,85 @@ class TestStream:
|
||||
call_kwargs = agent.stream.call_args.kwargs
|
||||
assert "messages" in call_kwargs["stream_mode"]
|
||||
|
||||
def test_stream_emits_additional_kwargs_updates_for_streamed_ai_messages(self, client):
|
||||
"""stream() emits a follow-up AI event when attribution metadata arrives via values."""
|
||||
assembled = AIMessage(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
},
|
||||
)
|
||||
agent = MagicMock()
|
||||
agent.stream.return_value = iter(
|
||||
[
|
||||
("messages", (AIMessageChunk(content="Hello!", id="ai-1"), {})),
|
||||
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
events = list(client.stream("hi", thread_id="t-stream-kwargs"))
|
||||
|
||||
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
|
||||
assert any(event.data.get("content") == "Hello!" for event in ai_events)
|
||||
assert any(event.data.get("additional_kwargs", {}).get("token_usage_attribution", {}).get("kind") == "final_answer" for event in ai_events)
|
||||
|
||||
def test_stream_emits_new_additional_kwargs_after_prior_metadata(self, client):
|
||||
"""stream() emits later attribution metadata even after earlier kwargs for the same id."""
|
||||
attribution = {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
assembled = AIMessage(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={
|
||||
"reasoning_content": "Thinking first.",
|
||||
"token_usage_attribution": attribution,
|
||||
},
|
||||
)
|
||||
agent = MagicMock()
|
||||
agent.stream.return_value = iter(
|
||||
[
|
||||
(
|
||||
"messages",
|
||||
(
|
||||
AIMessageChunk(
|
||||
content="Hello!",
|
||||
id="ai-1",
|
||||
additional_kwargs={"reasoning_content": "Thinking first."},
|
||||
),
|
||||
{},
|
||||
),
|
||||
),
|
||||
("values", {"messages": [HumanMessage(content="hi", id="h-1"), assembled]}),
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(client, "_ensure_agent"),
|
||||
patch.object(client, "_agent", agent),
|
||||
):
|
||||
events = list(client.stream("hi", thread_id="t-stream-kwargs-delta"))
|
||||
|
||||
ai_events = [event for event in events if event.type == "messages-tuple" and event.data.get("type") == "ai" and event.data.get("id") == "ai-1"]
|
||||
metadata_events = [event for event in ai_events if event.data.get("additional_kwargs")]
|
||||
|
||||
assert metadata_events[0].data["additional_kwargs"] == {"reasoning_content": "Thinking first."}
|
||||
assert metadata_events[1].data["content"] == ""
|
||||
assert metadata_events[1].data["additional_kwargs"] == {"token_usage_attribution": attribution}
|
||||
|
||||
def test_chat_accumulates_streamed_deltas(self, client):
|
||||
"""chat() concatenates per-id deltas from messages mode."""
|
||||
agent = MagicMock()
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
"""Tests for DeerFlowClient message serialization helpers."""
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
|
||||
from deerflow.client import DeerFlowClient
|
||||
|
||||
|
||||
def test_serialize_ai_message_preserves_additional_kwargs():
|
||||
message = AIMessage(
|
||||
content="done",
|
||||
additional_kwargs={
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
},
|
||||
usage_metadata={"input_tokens": 12, "output_tokens": 3, "total_tokens": 15},
|
||||
)
|
||||
|
||||
serialized = DeerFlowClient._serialize_message(message)
|
||||
|
||||
assert serialized["type"] == "ai"
|
||||
assert serialized["usage_metadata"] == {
|
||||
"input_tokens": 12,
|
||||
"output_tokens": 3,
|
||||
"total_tokens": 15,
|
||||
}
|
||||
assert serialized["additional_kwargs"] == {
|
||||
"token_usage_attribution": {
|
||||
"version": 1,
|
||||
"kind": "final_answer",
|
||||
"shared_attribution": False,
|
||||
"actions": [],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def test_serialize_human_message_preserves_additional_kwargs():
|
||||
message = HumanMessage(
|
||||
content="hello",
|
||||
additional_kwargs={"files": [{"name": "diagram.png"}]},
|
||||
)
|
||||
|
||||
serialized = DeerFlowClient._serialize_message(message)
|
||||
|
||||
assert serialized == {
|
||||
"type": "human",
|
||||
"content": "hello",
|
||||
"id": None,
|
||||
"additional_kwargs": {"files": [{"name": "diagram.png"}]},
|
||||
}
|
||||
@@ -1,32 +1,157 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
"""Tests for TokenUsageMiddleware attribution annotations."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
|
||||
from deerflow.agents.middlewares.token_usage_middleware import TokenUsageMiddleware
|
||||
from deerflow.agents.middlewares.token_usage_middleware import (
|
||||
TOKEN_USAGE_ATTRIBUTION_KEY,
|
||||
TokenUsageMiddleware,
|
||||
)
|
||||
|
||||
|
||||
def test_after_model_logs_usage_metadata_counts():
|
||||
middleware = TokenUsageMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
AIMessage(
|
||||
content="done",
|
||||
usage_metadata={
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
},
|
||||
)
|
||||
def _make_runtime():
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "test-thread"}
|
||||
return runtime
|
||||
|
||||
|
||||
class TestTokenUsageMiddleware:
|
||||
def test_annotates_todo_updates_with_structured_actions(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "write_todos:1",
|
||||
"name": "write_todos",
|
||||
"args": {
|
||||
"todos": [
|
||||
{"content": "Inspect streaming path", "status": "completed"},
|
||||
{"content": "Design token attribution schema", "status": "in_progress"},
|
||||
]
|
||||
},
|
||||
}
|
||||
],
|
||||
usage_metadata={"input_tokens": 100, "output_tokens": 20, "total_tokens": 120},
|
||||
)
|
||||
|
||||
state = {
|
||||
"messages": [message],
|
||||
"todos": [
|
||||
{"content": "Inspect streaming path", "status": "in_progress"},
|
||||
{"content": "Design token attribution schema", "status": "pending"},
|
||||
],
|
||||
}
|
||||
|
||||
result = middleware.after_model(state, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
updated_message = result["messages"][0]
|
||||
attribution = updated_message.additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "tool_batch"
|
||||
assert attribution["shared_attribution"] is True
|
||||
assert attribution["tool_call_ids"] == ["write_todos:1"]
|
||||
assert attribution["actions"] == [
|
||||
{
|
||||
"kind": "todo_complete",
|
||||
"content": "Inspect streaming path",
|
||||
"tool_call_id": "write_todos:1",
|
||||
},
|
||||
{
|
||||
"kind": "todo_start",
|
||||
"content": "Design token attribution schema",
|
||||
"tool_call_id": "write_todos:1",
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
with patch("deerflow.agents.middlewares.token_usage_middleware.logger.info") as info_mock:
|
||||
result = middleware.after_model(state=state, runtime=MagicMock())
|
||||
def test_annotates_subagent_and_search_steps(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "task:1",
|
||||
"name": "task",
|
||||
"args": {
|
||||
"description": "spec-coder patch message grouping",
|
||||
"subagent_type": "general-purpose",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "web_search:1",
|
||||
"name": "web_search",
|
||||
"args": {"query": "LangGraph useStream messages tuple"},
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
assert result is None
|
||||
info_mock.assert_called_once_with(
|
||||
"LLM token usage: input=%s output=%s total=%s",
|
||||
10,
|
||||
5,
|
||||
15,
|
||||
)
|
||||
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "tool_batch"
|
||||
assert attribution["shared_attribution"] is True
|
||||
assert attribution["actions"] == [
|
||||
{
|
||||
"kind": "subagent",
|
||||
"description": "spec-coder patch message grouping",
|
||||
"subagent_type": "general-purpose",
|
||||
"tool_call_id": "task:1",
|
||||
},
|
||||
{
|
||||
"kind": "search",
|
||||
"tool_name": "web_search",
|
||||
"query": "LangGraph useStream messages tuple",
|
||||
"tool_call_id": "web_search:1",
|
||||
},
|
||||
]
|
||||
|
||||
def test_marks_final_answer_when_no_tools(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(content="Here is the final answer.")
|
||||
|
||||
result = middleware.after_model({"messages": [message]}, _make_runtime())
|
||||
|
||||
assert result is not None
|
||||
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "final_answer"
|
||||
assert attribution["shared_attribution"] is False
|
||||
assert attribution["actions"] == []
|
||||
|
||||
def test_annotates_removed_todos(self):
|
||||
middleware = TokenUsageMiddleware()
|
||||
message = AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"id": "write_todos:remove",
|
||||
"name": "write_todos",
|
||||
"args": {
|
||||
"todos": [],
|
||||
},
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
result = middleware.after_model(
|
||||
{
|
||||
"messages": [message],
|
||||
"todos": [
|
||||
{"content": "Archive obsolete plan", "status": "pending"},
|
||||
],
|
||||
},
|
||||
_make_runtime(),
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
attribution = result["messages"][0].additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY]
|
||||
assert attribution["kind"] == "todo_update"
|
||||
assert attribution["shared_attribution"] is False
|
||||
assert attribution["actions"] == [
|
||||
{
|
||||
"kind": "todo_remove",
|
||||
"content": "Archive obsolete plan",
|
||||
"tool_call_id": "write_todos:remove",
|
||||
}
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user