fix: bucket subagent token usage into parent run totals (#2838)

* fix: bucket subagent token usage into RunRow.subagent_tokens

Add caller-bucketed token tracking to RunJournal so subagent and
middleware LLM calls are written to the correct RunRow columns instead
of all falling into lead_agent_tokens (default 0).

- RunJournal: accumulate _lead_agent_tokens / _subagent_tokens /
  _middleware_tokens in on_llm_end, deduped by langchain run_id.
  Add record_external_llm_usage_records() for external sources
  (respects track_token_usage flag). Return caller buckets from
  get_completion_data().
- SubagentTokenCollector: new lightweight callback handler that
  collects LLM usage within subagent execution.
- SubagentExecutor: wire collector into subagent run_config and sync
  records to SubagentResult on every chunk (timeout/cancel safe).
- SubagentResult: add token_usage_records and usage_reported fields.
- task_tool: report subagent usage to parent RunJournal on every
  terminal status (COMPLETED/FAILED/CANCELLED/TIMED_OUT), including
  the CancelledError path, guarded against double-reporting.

No DB migration needed — RunRow columns already exist.

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix: address token usage review feedback

* Address review follow-ups

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
YuJitang
2026-05-10 22:47:30 +08:00
committed by GitHub
parent 94da8f67d7
commit 9892a7d468
8 changed files with 843 additions and 77 deletions
@@ -63,6 +63,15 @@ class RunJournal(BaseCallbackHandler):
self._total_tokens = 0
self._llm_call_count = 0
# Caller-bucketed token accumulators
self._lead_agent_tokens = 0
self._subagent_tokens = 0
self._middleware_tokens = 0
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id
self._counted_llm_run_ids: set[str] = set()
self._counted_external_source_ids: set[str] = set()
# Convenience fields
self._last_ai_msg: str | None = None
self._first_human_msg: str | None = None
@@ -214,19 +223,28 @@ class RunJournal(BaseCallbackHandler):
},
)
# Token accumulation
# Token accumulation (dedup by langchain run_id to avoid double-counting
# when the callback fires more than once for the same response)
if self._track_tokens:
input_tk = usage_dict.get("input_tokens", 0) or 0
output_tk = usage_dict.get("output_tokens", 0) or 0
total_tk = usage_dict.get("total_tokens", 0) or 0
if total_tk == 0:
total_tk = input_tk + output_tk
if total_tk > 0:
if total_tk > 0 and rid not in self._counted_llm_run_ids:
self._counted_llm_run_ids.add(rid)
self._total_input_tokens += input_tk
self._total_output_tokens += output_tk
self._total_tokens += total_tk
self._llm_call_count += 1
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
self._llm_start_times.pop(str(run_id), None)
self._put(event_type="llm.error", category="trace", content=str(error))
@@ -330,6 +348,49 @@ class RunJournal(BaseCallbackHandler):
# -- Public methods (called by worker) --
def record_external_llm_usage_records(
self,
records: list[dict[str, int | str]],
) -> None:
"""Record token usage from external sources (e.g., subagents).
Each record should contain:
source_run_id: Unique identifier to prevent double-counting
caller: Caller tag (e.g. "subagent:general-purpose")
input_tokens: Input token count
output_tokens: Output token count
total_tokens: Total token count (computed from input+output if 0/missing)
"""
if not self._track_tokens:
return
for record in records:
source_id = str(record.get("source_run_id", ""))
if not source_id:
continue
if source_id in self._counted_external_source_ids:
continue
total_tk = record.get("total_tokens", 0) or 0
if total_tk <= 0:
input_tk = record.get("input_tokens", 0) or 0
output_tk = record.get("output_tokens", 0) or 0
total_tk = input_tk + output_tk
if total_tk <= 0:
continue
self._counted_external_source_ids.add(source_id)
self._total_input_tokens += record.get("input_tokens", 0) or 0
self._total_output_tokens += record.get("output_tokens", 0) or 0
self._total_tokens += total_tk
caller = str(record.get("caller", ""))
if caller.startswith("subagent:"):
self._subagent_tokens += total_tk
elif caller.startswith("middleware:"):
self._middleware_tokens += total_tk
else:
self._lead_agent_tokens += total_tk
def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
@@ -376,6 +437,9 @@ class RunJournal(BaseCallbackHandler):
"total_output_tokens": self._total_output_tokens,
"total_tokens": self._total_tokens,
"llm_call_count": self._llm_call_count,
"lead_agent_tokens": self._lead_agent_tokens,
"subagent_tokens": self._subagent_tokens,
"middleware_tokens": self._middleware_tokens,
"message_count": self._msg_count,
"last_ai_message": self._last_ai_msg,
"first_human_message": self._first_human_msg,