mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
fix: bucket subagent token usage into parent run totals (#2838)
* fix: bucket subagent token usage into RunRow.subagent_tokens Add caller-bucketed token tracking to RunJournal so subagent and middleware LLM calls are written to the correct RunRow columns instead of all falling into lead_agent_tokens (default 0). - RunJournal: accumulate _lead_agent_tokens / _subagent_tokens / _middleware_tokens in on_llm_end, deduped by langchain run_id. Add record_external_llm_usage_records() for external sources (respects track_token_usage flag). Return caller buckets from get_completion_data(). - SubagentTokenCollector: new lightweight callback handler that collects LLM usage within subagent execution. - SubagentExecutor: wire collector into subagent run_config and sync records to SubagentResult on every chunk (timeout/cancel safe). - SubagentResult: add token_usage_records and usage_reported fields. - task_tool: report subagent usage to parent RunJournal on every terminal status (COMPLETED/FAILED/CANCELLED/TIMED_OUT), including the CancelledError path, guarded against double-reporting. No DB migration needed — RunRow columns already exist. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix: address token usage review feedback * Address review follow-ups --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -383,6 +383,244 @@ class TestMiddlewareEvents:
|
||||
assert "middleware:guardrail" in event_types
|
||||
|
||||
|
||||
class TestCallerBucketing:
|
||||
"""Tests for caller-bucketed token accumulation (lead_agent / subagent / middleware)."""
|
||||
|
||||
def test_lead_agent_bucketing(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
assert j._lead_agent_tokens == 15
|
||||
assert j._subagent_tokens == 0
|
||||
assert j._middleware_tokens == 0
|
||||
|
||||
def test_subagent_bucketing(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
usage = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
|
||||
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
|
||||
assert j._subagent_tokens == 30
|
||||
assert j._lead_agent_tokens == 0
|
||||
assert j._middleware_tokens == 0
|
||||
|
||||
def test_middleware_bucketing(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
usage = {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7}
|
||||
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarize"])
|
||||
assert j._middleware_tokens == 7
|
||||
assert j._lead_agent_tokens == 0
|
||||
assert j._subagent_tokens == 0
|
||||
|
||||
def test_mixed_callers_sum_independently(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:bash"])
|
||||
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:title"])
|
||||
assert j._lead_agent_tokens == 15
|
||||
assert j._subagent_tokens == 15
|
||||
assert j._middleware_tokens == 15
|
||||
assert j._total_tokens == 45
|
||||
|
||||
def test_get_completion_data_includes_buckets(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
j._lead_agent_tokens = 100
|
||||
j._subagent_tokens = 200
|
||||
j._middleware_tokens = 50
|
||||
data = j.get_completion_data()
|
||||
assert data["lead_agent_tokens"] == 100
|
||||
assert data["subagent_tokens"] == 200
|
||||
assert data["middleware_tokens"] == 50
|
||||
|
||||
def test_dedup_same_run_id(self, journal_setup):
|
||||
"""Same langchain run_id in on_llm_end must not double-count."""
|
||||
j, _ = journal_setup
|
||||
run_id = uuid4()
|
||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
|
||||
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
|
||||
assert j._total_tokens == 15
|
||||
assert j._lead_agent_tokens == 15
|
||||
assert j._llm_call_count == 1
|
||||
|
||||
def test_first_no_usage_second_with_usage(self, journal_setup):
|
||||
"""First callback with no usage must not block second callback with usage for same run_id."""
|
||||
j, _ = journal_setup
|
||||
run_id = uuid4()
|
||||
j.on_llm_end(_make_llm_response("A", usage=None), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
|
||||
assert str(run_id) not in j._counted_llm_run_ids
|
||||
# Second callback for the same run_id with actual usage must still count
|
||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
|
||||
assert j._total_tokens == 15
|
||||
assert j._lead_agent_tokens == 15
|
||||
|
||||
def test_track_token_usage_false_skips_buckets(self):
|
||||
"""When token tracking is disabled, caller buckets stay at 0."""
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
|
||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
j.on_llm_end(_make_llm_response("X", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
|
||||
assert j._subagent_tokens == 0
|
||||
assert j._lead_agent_tokens == 0
|
||||
|
||||
def test_default_no_tags_buckets_as_lead_agent(self, journal_setup):
|
||||
"""LLM calls without explicit tags default to lead_agent bucket."""
|
||||
j, _ = journal_setup
|
||||
usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}
|
||||
j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None)
|
||||
assert j._lead_agent_tokens == 10
|
||||
assert j._subagent_tokens == 0
|
||||
assert j._middleware_tokens == 0
|
||||
|
||||
def test_unknown_tag_buckets_as_lead_agent(self, journal_setup):
|
||||
"""Calls with unrecognized tags (not lead_agent/subagent:/middleware:) go to lead_agent."""
|
||||
j, _ = journal_setup
|
||||
usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}
|
||||
j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["some_random_tag"])
|
||||
assert j._lead_agent_tokens == 10
|
||||
|
||||
|
||||
class TestExternalUsageRecords:
|
||||
"""Tests for record_external_llm_usage_records."""
|
||||
|
||||
def test_records_added_to_subagent_bucket(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
records = [
|
||||
{
|
||||
"source_run_id": "ext-1",
|
||||
"caller": "subagent:general-purpose",
|
||||
"input_tokens": 100,
|
||||
"output_tokens": 50,
|
||||
"total_tokens": 150,
|
||||
}
|
||||
]
|
||||
j.record_external_llm_usage_records(records)
|
||||
assert j._subagent_tokens == 150
|
||||
assert j._total_tokens == 150
|
||||
assert j._total_input_tokens == 100
|
||||
assert j._total_output_tokens == 50
|
||||
|
||||
def test_records_added_to_middleware_bucket(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
records = [
|
||||
{
|
||||
"source_run_id": "ext-2",
|
||||
"caller": "middleware:summarize",
|
||||
"input_tokens": 30,
|
||||
"output_tokens": 10,
|
||||
"total_tokens": 40,
|
||||
}
|
||||
]
|
||||
j.record_external_llm_usage_records(records)
|
||||
assert j._middleware_tokens == 40
|
||||
assert j._lead_agent_tokens == 0
|
||||
assert j._subagent_tokens == 0
|
||||
|
||||
def test_records_added_to_lead_agent_bucket(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
records = [
|
||||
{
|
||||
"source_run_id": "ext-3",
|
||||
"caller": "lead_agent",
|
||||
"input_tokens": 10,
|
||||
"output_tokens": 5,
|
||||
"total_tokens": 15,
|
||||
}
|
||||
]
|
||||
j.record_external_llm_usage_records(records)
|
||||
assert j._lead_agent_tokens == 15
|
||||
|
||||
def test_dedup_same_source_run_id(self, journal_setup):
|
||||
"""Same source_run_id must not be double-counted."""
|
||||
j, _ = journal_setup
|
||||
records = [
|
||||
{
|
||||
"source_run_id": "dup-1",
|
||||
"caller": "subagent:research",
|
||||
"input_tokens": 50,
|
||||
"output_tokens": 25,
|
||||
"total_tokens": 75,
|
||||
}
|
||||
]
|
||||
j.record_external_llm_usage_records(records)
|
||||
j.record_external_llm_usage_records(records)
|
||||
assert j._subagent_tokens == 75
|
||||
assert j._total_tokens == 75
|
||||
|
||||
def test_total_tokens_missing_computed_from_input_output(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
records = [
|
||||
{
|
||||
"source_run_id": "ext-4",
|
||||
"caller": "subagent:bash",
|
||||
"input_tokens": 200,
|
||||
"output_tokens": 100,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
]
|
||||
j.record_external_llm_usage_records(records)
|
||||
assert j._subagent_tokens == 300
|
||||
assert j._total_tokens == 300
|
||||
|
||||
def test_total_tokens_zero_no_count(self, journal_setup):
|
||||
"""Records with zero total and zero input+output must not be counted."""
|
||||
j, _ = journal_setup
|
||||
records = [
|
||||
{
|
||||
"source_run_id": "ext-5",
|
||||
"caller": "subagent:research",
|
||||
"input_tokens": 0,
|
||||
"output_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
]
|
||||
j.record_external_llm_usage_records(records)
|
||||
assert j._total_tokens == 0
|
||||
assert j._subagent_tokens == 0
|
||||
|
||||
def test_empty_source_run_id_skipped(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
records = [
|
||||
{
|
||||
"source_run_id": "",
|
||||
"caller": "subagent:research",
|
||||
"input_tokens": 50,
|
||||
"output_tokens": 25,
|
||||
"total_tokens": 75,
|
||||
}
|
||||
]
|
||||
j.record_external_llm_usage_records(records)
|
||||
assert j._total_tokens == 0
|
||||
|
||||
def test_multiple_records_in_single_call(self, journal_setup):
|
||||
j, _ = journal_setup
|
||||
records = [
|
||||
{"source_run_id": "r1", "caller": "subagent:gp", "input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
|
||||
{"source_run_id": "r2", "caller": "subagent:bash", "input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
|
||||
]
|
||||
j.record_external_llm_usage_records(records)
|
||||
assert j._subagent_tokens == 45
|
||||
assert j._total_tokens == 45
|
||||
|
||||
def test_external_records_coexist_with_inline_callbacks(self, journal_setup):
|
||||
"""External records and inline on_llm_end must not interfere."""
|
||||
j, _ = journal_setup
|
||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
||||
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
||||
j.record_external_llm_usage_records([{"source_run_id": "ext-6", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}])
|
||||
assert j._lead_agent_tokens == 15
|
||||
assert j._subagent_tokens == 150
|
||||
assert j._total_tokens == 165
|
||||
|
||||
def test_track_token_usage_false_skips_external_records(self):
|
||||
"""When token tracking is disabled, external records must not accumulate."""
|
||||
store = MemoryRunEventStore()
|
||||
j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
|
||||
j.record_external_llm_usage_records([{"source_run_id": "ext-7", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}])
|
||||
assert j._total_tokens == 0
|
||||
assert j._subagent_tokens == 0
|
||||
|
||||
|
||||
class TestChatModelStartHumanMessage:
|
||||
"""Tests for on_chat_model_start extracting the first human message."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user