9892a7d468
* fix: bucket subagent token usage into RunRow.subagent_tokens Add caller-bucketed token tracking to RunJournal so subagent and middleware LLM calls are written to the correct RunRow columns instead of all falling into lead_agent_tokens (default 0). - RunJournal: accumulate _lead_agent_tokens / _subagent_tokens / _middleware_tokens in on_llm_end, deduped by langchain run_id. Add record_external_llm_usage_records() for external sources (respects track_token_usage flag). Return caller buckets from get_completion_data(). - SubagentTokenCollector: new lightweight callback handler that collects LLM usage within subagent execution. - SubagentExecutor: wire collector into subagent run_config and sync records to SubagentResult on every chunk (timeout/cancel safe). - SubagentResult: add token_usage_records and usage_reported fields. - task_tool: report subagent usage to parent RunJournal on every terminal status (COMPLETED/FAILED/CANCELLED/TIMED_OUT), including the CancelledError path, guarded against double-reporting. No DB migration needed — RunRow columns already exist. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix: address token usage review feedback * Address review follow-ups --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
"""Callback handler that collects LLM token usage within a subagent.
|
|
|
|
Each subagent execution creates its own collector. After the subagent
|
|
finishes, the collected records are transferred to the parent RunJournal
|
|
via :meth:`RunJournal.record_external_llm_usage_records`.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from langchain_core.callbacks import BaseCallbackHandler
|
|
|
|
|
|
class SubagentTokenCollector(BaseCallbackHandler):
    """Callback handler that gathers LLM token usage records for a subagent.

    Every stored record carries the ``caller`` label supplied at construction
    time, and a given ``run_id`` contributes at most one record.
    """

    def __init__(self, caller: str):
        """Create an empty collector whose records are labelled with *caller*."""
        super().__init__()
        self.caller = caller
        # One dict per counted LLM run, kept in arrival order.
        self._records: list[dict[str, int | str]] = []
        # Stringified run ids that have already produced a record.
        self._counted_run_ids: set[str] = set()

    def on_llm_end(
        self,
        response: Any,
        *,
        run_id: Any,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Store the token usage of a finished LLM call, once per run id.

        Walks ``response.generations`` in order and records the first entry
        whose message reports a positive token total; later entries of the
        same response are not inspected. When ``total_tokens`` is absent or
        not positive, the sum of input and output tokens is used instead.
        Nothing is stored (and the run id stays uncounted) if no entry
        yields a positive total.

        Args:
            response: Object exposing ``generations`` as an iterable of
                iterables of generation entries.
            run_id: Identifier of the LLM run; compared by its ``str`` form.
            tags: Accepted for signature compatibility; not used here.
            **kwargs: Accepted for signature compatibility; not used here.
        """
        run_key = str(run_id)
        if run_key in self._counted_run_ids:
            return

        # Lazy flattening keeps the original visiting order and early exit.
        flattened = (entry for batch in response.generations for entry in batch)
        for entry in flattened:
            if not hasattr(entry, "message"):
                continue

            metadata = getattr(entry.message, "usage_metadata", None)
            counts = dict(metadata) if metadata else {}
            prompt_count = counts.get("input_tokens", 0) or 0
            completion_count = counts.get("output_tokens", 0) or 0
            reported_total = counts.get("total_tokens", 0) or 0

            if reported_total > 0:
                combined = reported_total
            else:
                combined = prompt_count + completion_count
            if combined <= 0:
                continue

            record: dict[str, int | str] = {
                "source_run_id": run_key,
                "caller": self.caller,
                "input_tokens": prompt_count,
                "output_tokens": completion_count,
                "total_tokens": combined,
            }
            self._counted_run_ids.add(run_key)
            self._records.append(record)
            return

    def snapshot_records(self) -> list[dict[str, int | str]]:
        """Return a new list holding the usage records collected so far.

        The list itself is a copy; the record dicts inside it are the same
        objects the collector holds.
        """
        return self._records.copy()
|