mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
fix: bucket subagent token usage into parent run totals (#2838)
* fix: bucket subagent token usage into RunRow.subagent_tokens Add caller-bucketed token tracking to RunJournal so subagent and middleware LLM calls are written to the correct RunRow columns instead of all falling into lead_agent_tokens (default 0). - RunJournal: accumulate _lead_agent_tokens / _subagent_tokens / _middleware_tokens in on_llm_end, deduped by langchain run_id. Add record_external_llm_usage_records() for external sources (respects track_token_usage flag). Return caller buckets from get_completion_data(). - SubagentTokenCollector: new lightweight callback handler that collects LLM usage within subagent execution. - SubagentExecutor: wire collector into subagent run_config and sync records to SubagentResult on every chunk (timeout/cancel safe). - SubagentResult: add token_usage_records and usage_reported fields. - task_tool: report subagent usage to parent RunJournal on every terminal status (COMPLETED/FAILED/CANCELLED/TIMED_OUT), including the CancelledError path, guarded against double-reporting. No DB migration needed — RunRow columns already exist. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix: address token usage review feedback * Address review follow-ups --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -63,6 +63,15 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._total_tokens = 0
|
||||
self._llm_call_count = 0
|
||||
|
||||
# Caller-bucketed token accumulators
|
||||
self._lead_agent_tokens = 0
|
||||
self._subagent_tokens = 0
|
||||
self._middleware_tokens = 0
|
||||
|
||||
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id
|
||||
self._counted_llm_run_ids: set[str] = set()
|
||||
self._counted_external_source_ids: set[str] = set()
|
||||
|
||||
# Convenience fields
|
||||
self._last_ai_msg: str | None = None
|
||||
self._first_human_msg: str | None = None
|
||||
@@ -214,19 +223,28 @@ class RunJournal(BaseCallbackHandler):
|
||||
},
|
||||
)
|
||||
|
||||
# Token accumulation
|
||||
# Token accumulation (dedup by langchain run_id to avoid double-counting
|
||||
# when the callback fires more than once for the same response)
|
||||
if self._track_tokens:
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk == 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk > 0:
|
||||
if total_tk > 0 and rid not in self._counted_llm_run_ids:
|
||||
self._counted_llm_run_ids.add(rid)
|
||||
self._total_input_tokens += input_tk
|
||||
self._total_output_tokens += output_tk
|
||||
self._total_tokens += total_tk
|
||||
self._llm_call_count += 1
|
||||
|
||||
if caller.startswith("subagent:"):
|
||||
self._subagent_tokens += total_tk
|
||||
elif caller.startswith("middleware:"):
|
||||
self._middleware_tokens += total_tk
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||
self._llm_start_times.pop(str(run_id), None)
|
||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||
@@ -330,6 +348,49 @@ class RunJournal(BaseCallbackHandler):
|
||||
|
||||
# -- Public methods (called by worker) --
|
||||
|
||||
def record_external_llm_usage_records(
|
||||
self,
|
||||
records: list[dict[str, int | str]],
|
||||
) -> None:
|
||||
"""Record token usage from external sources (e.g., subagents).
|
||||
|
||||
Each record should contain:
|
||||
source_run_id: Unique identifier to prevent double-counting
|
||||
caller: Caller tag (e.g. "subagent:general-purpose")
|
||||
input_tokens: Input token count
|
||||
output_tokens: Output token count
|
||||
total_tokens: Total token count (computed from input+output if 0/missing)
|
||||
"""
|
||||
if not self._track_tokens:
|
||||
return
|
||||
for record in records:
|
||||
source_id = str(record.get("source_run_id", ""))
|
||||
if not source_id:
|
||||
continue
|
||||
if source_id in self._counted_external_source_ids:
|
||||
continue
|
||||
|
||||
total_tk = record.get("total_tokens", 0) or 0
|
||||
if total_tk <= 0:
|
||||
input_tk = record.get("input_tokens", 0) or 0
|
||||
output_tk = record.get("output_tokens", 0) or 0
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk <= 0:
|
||||
continue
|
||||
|
||||
self._counted_external_source_ids.add(source_id)
|
||||
self._total_input_tokens += record.get("input_tokens", 0) or 0
|
||||
self._total_output_tokens += record.get("output_tokens", 0) or 0
|
||||
self._total_tokens += total_tk
|
||||
|
||||
caller = str(record.get("caller", ""))
|
||||
if caller.startswith("subagent:"):
|
||||
self._subagent_tokens += total_tk
|
||||
elif caller.startswith("middleware:"):
|
||||
self._middleware_tokens += total_tk
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
def set_first_human_message(self, content: str) -> None:
|
||||
"""Record the first human message for convenience fields."""
|
||||
self._first_human_msg = content[:2000] if content else None
|
||||
@@ -376,6 +437,9 @@ class RunJournal(BaseCallbackHandler):
|
||||
"total_output_tokens": self._total_output_tokens,
|
||||
"total_tokens": self._total_tokens,
|
||||
"llm_call_count": self._llm_call_count,
|
||||
"lead_agent_tokens": self._lead_agent_tokens,
|
||||
"subagent_tokens": self._subagent_tokens,
|
||||
"middleware_tokens": self._middleware_tokens,
|
||||
"message_count": self._msg_count,
|
||||
"last_ai_message": self._last_ai_msg,
|
||||
"first_human_message": self._first_human_msg,
|
||||
|
||||
@@ -26,6 +26,7 @@ from deerflow.models import create_chat_model
|
||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||
from deerflow.skills.types import Skill
|
||||
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
|
||||
from deerflow.subagents.token_collector import SubagentTokenCollector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -70,6 +71,8 @@ class SubagentResult:
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
ai_messages: list[dict[str, Any]] | None = None
|
||||
token_usage_records: list[dict[str, int | str]] = field(default_factory=list)
|
||||
usage_reported: bool = False
|
||||
cancel_event: threading.Event = field(default_factory=threading.Event, repr=False)
|
||||
|
||||
def __post_init__(self):
|
||||
@@ -412,13 +415,20 @@ class SubagentExecutor:
|
||||
ai_messages = []
|
||||
result.ai_messages = ai_messages
|
||||
|
||||
collector: SubagentTokenCollector | None = None
|
||||
try:
|
||||
state, filtered_tools = await self._build_initial_state(task)
|
||||
agent = self._create_agent(filtered_tools)
|
||||
|
||||
# Token collector for subagent LLM calls
|
||||
collector_caller = f"subagent:{self.config.name}"
|
||||
collector = SubagentTokenCollector(caller=collector_caller)
|
||||
|
||||
# Build config with thread_id for sandbox access and recursion limit
|
||||
run_config: RunnableConfig = {
|
||||
"recursion_limit": self.config.max_turns,
|
||||
"callbacks": [collector],
|
||||
"tags": [collector_caller],
|
||||
}
|
||||
context: dict[str, Any] = {}
|
||||
if self.thread_id:
|
||||
@@ -441,6 +451,8 @@ class SubagentExecutor:
|
||||
result.status = SubagentStatus.CANCELLED
|
||||
result.error = "Cancelled by user"
|
||||
result.completed_at = datetime.now()
|
||||
if collector is not None:
|
||||
result.token_usage_records = collector.snapshot_records()
|
||||
return result
|
||||
|
||||
async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
|
||||
@@ -455,6 +467,7 @@ class SubagentExecutor:
|
||||
result.status = SubagentStatus.CANCELLED
|
||||
result.error = "Cancelled by user"
|
||||
result.completed_at = datetime.now()
|
||||
result.token_usage_records = collector.snapshot_records()
|
||||
return result
|
||||
|
||||
final_state = chunk
|
||||
@@ -481,6 +494,7 @@ class SubagentExecutor:
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}")
|
||||
|
||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution")
|
||||
result.token_usage_records = collector.snapshot_records()
|
||||
|
||||
if final_state is None:
|
||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
|
||||
@@ -560,6 +574,8 @@ class SubagentExecutor:
|
||||
result.status = SubagentStatus.FAILED
|
||||
result.error = str(e)
|
||||
result.completed_at = datetime.now()
|
||||
if collector is not None:
|
||||
result.token_usage_records = collector.snapshot_records()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
"""Callback handler that collects LLM token usage within a subagent.
|
||||
|
||||
Each subagent execution creates its own collector. After the subagent
|
||||
finishes, the collected records are transferred to the parent RunJournal
|
||||
via :meth:`RunJournal.record_external_llm_usage_records`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.callbacks import BaseCallbackHandler
|
||||
|
||||
|
||||
class SubagentTokenCollector(BaseCallbackHandler):
|
||||
"""Lightweight callback handler that collects LLM token usage within a subagent."""
|
||||
|
||||
def __init__(self, caller: str):
|
||||
super().__init__()
|
||||
self.caller = caller
|
||||
self._records: list[dict[str, int | str]] = []
|
||||
self._counted_run_ids: set[str] = set()
|
||||
|
||||
def on_llm_end(
|
||||
self,
|
||||
response: Any,
|
||||
*,
|
||||
run_id: Any,
|
||||
tags: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
rid = str(run_id)
|
||||
if rid in self._counted_run_ids:
|
||||
return
|
||||
|
||||
for generation in response.generations:
|
||||
for gen in generation:
|
||||
if not hasattr(gen, "message"):
|
||||
continue
|
||||
usage = getattr(gen.message, "usage_metadata", None)
|
||||
usage_dict = dict(usage) if usage else {}
|
||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||
if total_tk <= 0:
|
||||
total_tk = input_tk + output_tk
|
||||
if total_tk <= 0:
|
||||
continue
|
||||
self._counted_run_ids.add(rid)
|
||||
self._records.append(
|
||||
{
|
||||
"source_run_id": rid,
|
||||
"caller": self.caller,
|
||||
"input_tokens": input_tk,
|
||||
"output_tokens": output_tk,
|
||||
"total_tokens": total_tk,
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
def snapshot_records(self) -> list[dict[str, int | str]]:
|
||||
"""Return a copy of the accumulated usage records."""
|
||||
return list(self._records)
|
||||
@@ -27,6 +27,92 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _is_subagent_terminal(result: Any) -> bool:
|
||||
"""Return whether a background subagent result is safe to clean up."""
|
||||
return result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None
|
||||
|
||||
|
||||
async def _await_subagent_terminal(task_id: str, max_polls: int) -> Any | None:
|
||||
"""Poll until the background subagent reaches a terminal status or we run out of polls."""
|
||||
for _ in range(max_polls):
|
||||
result = get_background_task_result(task_id)
|
||||
if result is None:
|
||||
return None
|
||||
if _is_subagent_terminal(result):
|
||||
return result
|
||||
await asyncio.sleep(5)
|
||||
return None
|
||||
|
||||
|
||||
async def _deferred_cleanup_subagent_task(task_id: str, trace_id: str, max_polls: int) -> None:
|
||||
"""Keep polling a cancelled subagent until it can be safely removed."""
|
||||
cleanup_poll_count = 0
|
||||
while True:
|
||||
result = get_background_task_result(task_id)
|
||||
if result is None:
|
||||
return
|
||||
if _is_subagent_terminal(result):
|
||||
cleanup_background_task(task_id)
|
||||
return
|
||||
if cleanup_poll_count >= max_polls:
|
||||
logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls")
|
||||
return
|
||||
await asyncio.sleep(5)
|
||||
cleanup_poll_count += 1
|
||||
|
||||
|
||||
def _log_cleanup_failure(cleanup_task: asyncio.Task[None], *, trace_id: str, task_id: str) -> None:
|
||||
if cleanup_task.cancelled():
|
||||
return
|
||||
|
||||
exc = cleanup_task.exception()
|
||||
if exc is not None:
|
||||
logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")
|
||||
|
||||
|
||||
def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls: int) -> None:
|
||||
logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
|
||||
cleanup_task = asyncio.create_task(_deferred_cleanup_subagent_task(task_id, trace_id, max_polls))
|
||||
cleanup_task.add_done_callback(lambda task: _log_cleanup_failure(task, trace_id=trace_id, task_id=task_id))
|
||||
|
||||
|
||||
def _find_usage_recorder(runtime: Any) -> Any | None:
|
||||
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config."""
|
||||
if runtime is None:
|
||||
return None
|
||||
config = getattr(runtime, "config", None)
|
||||
if not isinstance(config, dict):
|
||||
return None
|
||||
callbacks = config.get("callbacks", [])
|
||||
if not callbacks:
|
||||
return None
|
||||
for cb in callbacks:
|
||||
if hasattr(cb, "record_external_llm_usage_records"):
|
||||
return cb
|
||||
return None
|
||||
|
||||
|
||||
def _report_subagent_usage(runtime: Any, result: Any) -> None:
|
||||
"""Report subagent token usage to the parent RunJournal, if available.
|
||||
|
||||
Each subagent task must be reported only once (guarded by usage_reported).
|
||||
"""
|
||||
if getattr(result, "usage_reported", True):
|
||||
return
|
||||
records = getattr(result, "token_usage_records", None) or []
|
||||
if not records:
|
||||
return
|
||||
journal = _find_usage_recorder(runtime)
|
||||
if journal is None:
|
||||
logger.debug("No usage recorder found in runtime callbacks — subagent token usage not recorded")
|
||||
return
|
||||
try:
|
||||
journal.record_external_llm_usage_records(records)
|
||||
result.usage_reported = True
|
||||
except Exception:
|
||||
logger.warning("Failed to report subagent token usage", exc_info=True)
|
||||
|
||||
|
||||
def _get_runtime_app_config(runtime: Any) -> "AppConfig | None":
|
||||
context = getattr(runtime, "context", None)
|
||||
if isinstance(context, dict):
|
||||
@@ -227,21 +313,25 @@ async def task_tool(
|
||||
|
||||
# Check if task completed, failed, or timed out
|
||||
if result.status == SubagentStatus.COMPLETED:
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task Succeeded. Result: {result.result}"
|
||||
elif result.status == SubagentStatus.FAILED:
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return f"Task failed. Error: {result.error}"
|
||||
elif result.status == SubagentStatus.CANCELLED:
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
|
||||
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
return "Task cancelled by user."
|
||||
elif result.status == SubagentStatus.TIMED_OUT:
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
||||
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
||||
cleanup_background_task(task_id)
|
||||
@@ -260,43 +350,28 @@ async def task_tool(
|
||||
if poll_count > max_poll_count:
|
||||
timeout_minutes = config.timeout_seconds // 60
|
||||
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
||||
_report_subagent_usage(runtime, result)
|
||||
writer({"type": "task_timed_out", "task_id": task_id})
|
||||
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
||||
except asyncio.CancelledError:
|
||||
# Signal the background subagent thread to stop cooperatively.
|
||||
# Without this, the thread (running in ThreadPoolExecutor with its
|
||||
# own event loop via asyncio.run) would continue executing even
|
||||
# after the parent task is cancelled.
|
||||
request_cancel_background_task(task_id)
|
||||
|
||||
async def cleanup_when_done() -> None:
|
||||
max_cleanup_polls = max_poll_count
|
||||
cleanup_poll_count = 0
|
||||
# Wait (shielded) for the subagent to reach a terminal state so the
|
||||
# final token usage snapshot is reported to the parent RunJournal
|
||||
# before the parent worker persists get_completion_data().
|
||||
terminal_result = None
|
||||
try:
|
||||
terminal_result = await asyncio.shield(_await_subagent_terminal(task_id, max_poll_count))
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
while True:
|
||||
result = get_background_task_result(task_id)
|
||||
if result is None:
|
||||
return
|
||||
|
||||
if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
|
||||
cleanup_background_task(task_id)
|
||||
return
|
||||
|
||||
if cleanup_poll_count > max_cleanup_polls:
|
||||
logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls")
|
||||
return
|
||||
|
||||
await asyncio.sleep(5)
|
||||
cleanup_poll_count += 1
|
||||
|
||||
def log_cleanup_failure(cleanup_task: asyncio.Task[None]) -> None:
|
||||
if cleanup_task.cancelled():
|
||||
return
|
||||
|
||||
exc = cleanup_task.exception()
|
||||
if exc is not None:
|
||||
logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")
|
||||
|
||||
logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
|
||||
asyncio.create_task(cleanup_when_done()).add_done_callback(log_cleanup_failure)
|
||||
# Report whatever the subagent collected (even if we timed out).
|
||||
final_result = terminal_result or get_background_task_result(task_id)
|
||||
if final_result is not None:
|
||||
_report_subagent_usage(runtime, final_result)
|
||||
if final_result is not None and _is_subagent_terminal(final_result):
|
||||
cleanup_background_task(task_id)
|
||||
else:
|
||||
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user