mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
feat: stream subagent token usage to header via terminal task events (#2882)
* feat: real-time subagent token usage display in header and per-turn Backend: - Persist subagent token usage to AIMessage.usage_metadata via TokenUsageMiddleware, so accumulateUsage() naturally includes subagent tokens without frontend state management - Cache subagent usage by tool_call_id in task_tool, write back to the dispatching AIMessage on next model response - Emit subagent token usage on all terminal task events (task_completed, task_failed, task_cancelled, task_timed_out) - Report subagent usage to parent RunJournal for API totals - Search backward from ToolMessage to find dispatching AIMessage for correct multi-tool-call attribution Frontend: - Remove subagentUsage state, custom event handling, and prop threading — subagent tokens are now embedded in message metadata - Simplify selectHeaderTokenUsage (no subagentUsage parameter) - Per-turn inline badges show turn-specific usage via message accumulation - Remove isLoading guard from MessageTokenUsageList for dynamic updates during streaming * fix: prevent header token double counting from baseline reset race onFinish, onError, and thread-switch useEffect all reset pendingUsageBaselineMessageIdsRef to an empty Set. If thread.isLoading is still true on the next render, all messages pass the getMessagesAfterBaseline filter and their tokens are added to backendUsage (which already includes them), causing the header to display up to 2× the actual token count. Capture current message IDs instead of using an empty Set so that getMessagesAfterBaseline correctly returns no pending messages even if thread.isLoading lags behind the stream end. * fix: write back subagent tokens for all concurrent task tool calls TokenUsageMiddleware only processed messages[-2], so when a single model response dispatched multiple task tool calls only the last ToolMessage had its cached subagent usage written back to the dispatch AIMessage.usage_metadata. Earlier tasks' usage stayed in _subagent_usage_cache indefinitely (leak) and never appeared in the per-turn inline token display. Walk backward through all consecutive ToolMessages before the new AIMessage, and accumulate updates targeting the same dispatch message into one state update so overlapping writes don't clobber each other. * fix: clean up subagent usage cache entry on task cancellation When a task_tool invocation is cancelled via CancelledError, any cached subagent usage entry leaked because the TokenUsageMiddleware writeback path never fires after cancellation. Pop the cache entry before re-raising to prevent unbounded growth of the module-level _subagent_usage_cache dict. * fix: address token usage review feedback * fix: handle missing config for subagent usage cache --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -9,7 +9,7 @@ from typing import Any, override
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import AgentMiddleware
|
||||
from langchain.agents.middleware.todo import Todo
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
|
||||
return "thinking"
|
||||
|
||||
|
||||
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
|
||||
"""Return True if the AIMessage contains a tool_call with the given id."""
|
||||
for tc in message.tool_calls or []:
|
||||
if isinstance(tc, dict):
|
||||
if tc.get("id") == tool_call_id:
|
||||
return True
|
||||
elif hasattr(tc, "id") and tc.id == tool_call_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
|
||||
tool_calls = getattr(message, "tool_calls", None) or []
|
||||
actions: list[dict[str, Any]] = []
|
||||
@@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
if not messages:
|
||||
return None
|
||||
|
||||
# Annotate subagent token usage onto the AIMessage that dispatched it.
|
||||
# When a task tool completes, its usage is cached by tool_call_id. Detect
|
||||
# the ToolMessage → search backward for the corresponding AIMessage → merge.
|
||||
# Walk backward through consecutive ToolMessages before the new AIMessage
|
||||
# so that multiple concurrent task tool calls all get their subagent tokens
|
||||
# written back to the same dispatch message (merging into one update).
|
||||
state_updates: dict[int, AIMessage] = {}
|
||||
if len(messages) >= 2:
|
||||
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
|
||||
|
||||
idx = len(messages) - 2
|
||||
while idx >= 0:
|
||||
tool_msg = messages[idx]
|
||||
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
|
||||
break
|
||||
|
||||
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
|
||||
if subagent_usage:
|
||||
# Search backward from the ToolMessage to find the AIMessage
|
||||
# that dispatched it. A single model response can dispatch
|
||||
# multiple task tool calls, so we can't assume a fixed offset.
|
||||
dispatch_idx = idx - 1
|
||||
while dispatch_idx >= 0:
|
||||
candidate = messages[dispatch_idx]
|
||||
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
|
||||
# Accumulate into an existing update for the same
|
||||
# AIMessage (multiple task calls in one response),
|
||||
# or merge fresh from the original message.
|
||||
existing_update = state_updates.get(dispatch_idx)
|
||||
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
|
||||
merged = {
|
||||
**prev,
|
||||
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
|
||||
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
|
||||
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
|
||||
}
|
||||
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
|
||||
break
|
||||
dispatch_idx -= 1
|
||||
idx -= 1
|
||||
|
||||
last = messages[-1]
|
||||
if not isinstance(last, AIMessage):
|
||||
if state_updates:
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
||||
return None
|
||||
|
||||
usage = getattr(last, "usage_metadata", None)
|
||||
@@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware):
|
||||
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
|
||||
|
||||
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
|
||||
return None
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
|
||||
|
||||
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
|
||||
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
|
||||
return {"messages": [updated_msg]}
|
||||
state_updates[len(messages) - 1] = updated_msg
|
||||
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
|
||||
|
||||
@override
|
||||
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
|
||||
|
||||
Reference in New Issue
Block a user