fix: bucket subagent token usage into parent run totals (#2838)

* fix: bucket subagent token usage into RunRow.subagent_tokens

Add caller-bucketed token tracking to RunJournal so subagent and
middleware LLM calls are written to the correct RunRow columns instead
of all falling into lead_agent_tokens (default 0).

- RunJournal: accumulate _lead_agent_tokens / _subagent_tokens /
  _middleware_tokens in on_llm_end, deduped by langchain run_id.
  Add record_external_llm_usage_records() for external sources
  (respects track_token_usage flag). Return caller buckets from
  get_completion_data().
- SubagentTokenCollector: new lightweight callback handler that
  collects LLM usage within subagent execution.
- SubagentExecutor: wire collector into subagent run_config and sync
  records to SubagentResult on every chunk (timeout/cancel safe).
- SubagentResult: add token_usage_records and usage_reported fields.
- task_tool: report subagent usage to parent RunJournal on every
  terminal status (COMPLETED/FAILED/CANCELLED/TIMED_OUT), including
  the CancelledError path, guarded against double-reporting.

No DB migration needed — RunRow columns already exist.

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix: address token usage review feedback

* Address review follow-ups

---------

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
YuJitang
2026-05-10 22:47:30 +08:00
committed by GitHub
parent 94da8f67d7
commit 9892a7d468
8 changed files with 843 additions and 77 deletions
+238
View File
@@ -383,6 +383,244 @@ class TestMiddlewareEvents:
assert "middleware:guardrail" in event_types
class TestCallerBucketing:
"""Tests for caller-bucketed token accumulation (lead_agent / subagent / middleware)."""
def test_lead_agent_bucketing(self, journal_setup):
j, _ = journal_setup
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
assert j._lead_agent_tokens == 15
assert j._subagent_tokens == 0
assert j._middleware_tokens == 0
def test_subagent_bucketing(self, journal_setup):
j, _ = journal_setup
usage = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
assert j._subagent_tokens == 30
assert j._lead_agent_tokens == 0
assert j._middleware_tokens == 0
def test_middleware_bucketing(self, journal_setup):
j, _ = journal_setup
usage = {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7}
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:summarize"])
assert j._middleware_tokens == 7
assert j._lead_agent_tokens == 0
assert j._subagent_tokens == 0
def test_mixed_callers_sum_independently(self, journal_setup):
j, _ = journal_setup
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:bash"])
j.on_llm_end(_make_llm_response("C", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:title"])
assert j._lead_agent_tokens == 15
assert j._subagent_tokens == 15
assert j._middleware_tokens == 15
assert j._total_tokens == 45
def test_get_completion_data_includes_buckets(self, journal_setup):
j, _ = journal_setup
j._lead_agent_tokens = 100
j._subagent_tokens = 200
j._middleware_tokens = 50
data = j.get_completion_data()
assert data["lead_agent_tokens"] == 100
assert data["subagent_tokens"] == 200
assert data["middleware_tokens"] == 50
def test_dedup_same_run_id(self, journal_setup):
"""Same langchain run_id in on_llm_end must not double-count."""
j, _ = journal_setup
run_id = uuid4()
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
assert j._total_tokens == 15
assert j._lead_agent_tokens == 15
assert j._llm_call_count == 1
def test_first_no_usage_second_with_usage(self, journal_setup):
"""First callback with no usage must not block second callback with usage for same run_id."""
j, _ = journal_setup
run_id = uuid4()
j.on_llm_end(_make_llm_response("A", usage=None), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
assert str(run_id) not in j._counted_llm_run_ids
# Second callback for the same run_id with actual usage must still count
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id, parent_run_id=None, tags=["lead_agent"])
assert j._total_tokens == 15
assert j._lead_agent_tokens == 15
def test_track_token_usage_false_skips_buckets(self):
"""When token tracking is disabled, caller buckets stay at 0."""
store = MemoryRunEventStore()
j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("X", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
assert j._subagent_tokens == 0
assert j._lead_agent_tokens == 0
def test_default_no_tags_buckets_as_lead_agent(self, journal_setup):
"""LLM calls without explicit tags default to lead_agent bucket."""
j, _ = journal_setup
usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}
j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None)
assert j._lead_agent_tokens == 10
assert j._subagent_tokens == 0
assert j._middleware_tokens == 0
def test_unknown_tag_buckets_as_lead_agent(self, journal_setup):
"""Calls with unrecognized tags (not lead_agent/subagent:/middleware:) go to lead_agent."""
j, _ = journal_setup
usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}
j.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["some_random_tag"])
assert j._lead_agent_tokens == 10
class TestExternalUsageRecords:
"""Tests for record_external_llm_usage_records."""
def test_records_added_to_subagent_bucket(self, journal_setup):
j, _ = journal_setup
records = [
{
"source_run_id": "ext-1",
"caller": "subagent:general-purpose",
"input_tokens": 100,
"output_tokens": 50,
"total_tokens": 150,
}
]
j.record_external_llm_usage_records(records)
assert j._subagent_tokens == 150
assert j._total_tokens == 150
assert j._total_input_tokens == 100
assert j._total_output_tokens == 50
def test_records_added_to_middleware_bucket(self, journal_setup):
j, _ = journal_setup
records = [
{
"source_run_id": "ext-2",
"caller": "middleware:summarize",
"input_tokens": 30,
"output_tokens": 10,
"total_tokens": 40,
}
]
j.record_external_llm_usage_records(records)
assert j._middleware_tokens == 40
assert j._lead_agent_tokens == 0
assert j._subagent_tokens == 0
def test_records_added_to_lead_agent_bucket(self, journal_setup):
j, _ = journal_setup
records = [
{
"source_run_id": "ext-3",
"caller": "lead_agent",
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15,
}
]
j.record_external_llm_usage_records(records)
assert j._lead_agent_tokens == 15
def test_dedup_same_source_run_id(self, journal_setup):
"""Same source_run_id must not be double-counted."""
j, _ = journal_setup
records = [
{
"source_run_id": "dup-1",
"caller": "subagent:research",
"input_tokens": 50,
"output_tokens": 25,
"total_tokens": 75,
}
]
j.record_external_llm_usage_records(records)
j.record_external_llm_usage_records(records)
assert j._subagent_tokens == 75
assert j._total_tokens == 75
def test_total_tokens_missing_computed_from_input_output(self, journal_setup):
j, _ = journal_setup
records = [
{
"source_run_id": "ext-4",
"caller": "subagent:bash",
"input_tokens": 200,
"output_tokens": 100,
"total_tokens": 0,
}
]
j.record_external_llm_usage_records(records)
assert j._subagent_tokens == 300
assert j._total_tokens == 300
def test_total_tokens_zero_no_count(self, journal_setup):
"""Records with zero total and zero input+output must not be counted."""
j, _ = journal_setup
records = [
{
"source_run_id": "ext-5",
"caller": "subagent:research",
"input_tokens": 0,
"output_tokens": 0,
"total_tokens": 0,
}
]
j.record_external_llm_usage_records(records)
assert j._total_tokens == 0
assert j._subagent_tokens == 0
def test_empty_source_run_id_skipped(self, journal_setup):
j, _ = journal_setup
records = [
{
"source_run_id": "",
"caller": "subagent:research",
"input_tokens": 50,
"output_tokens": 25,
"total_tokens": 75,
}
]
j.record_external_llm_usage_records(records)
assert j._total_tokens == 0
def test_multiple_records_in_single_call(self, journal_setup):
j, _ = journal_setup
records = [
{"source_run_id": "r1", "caller": "subagent:gp", "input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
{"source_run_id": "r2", "caller": "subagent:bash", "input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
]
j.record_external_llm_usage_records(records)
assert j._subagent_tokens == 45
assert j._total_tokens == 45
def test_external_records_coexist_with_inline_callbacks(self, journal_setup):
"""External records and inline on_llm_end must not interfere."""
j, _ = journal_setup
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
j.record_external_llm_usage_records([{"source_run_id": "ext-6", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}])
assert j._lead_agent_tokens == 15
assert j._subagent_tokens == 150
assert j._total_tokens == 165
def test_track_token_usage_false_skips_external_records(self):
"""When token tracking is disabled, external records must not accumulate."""
store = MemoryRunEventStore()
j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
j.record_external_llm_usage_records([{"source_run_id": "ext-7", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}])
assert j._total_tokens == 0
assert j._subagent_tokens == 0
class TestChatModelStartHumanMessage:
"""Tests for on_chat_model_start extracting the first human message."""
@@ -0,0 +1,161 @@
"""Tests for SubagentTokenCollector callback handler."""
from unittest.mock import MagicMock
from uuid import uuid4
from deerflow.subagents.token_collector import SubagentTokenCollector
def _make_llm_response(content="Hello", usage=None):
"""Create a mock LLM response with a message."""
msg = MagicMock()
msg.content = content
msg.usage_metadata = usage
gen = MagicMock()
gen.message = msg
response = MagicMock()
response.generations = [[gen]]
return response
def _make_llm_response_from_usages(usages):
"""Create a mock LLM response with one generation per usage entry."""
generations = []
for usage in usages:
msg = MagicMock()
msg.content = "chunk"
msg.usage_metadata = usage
gen = MagicMock()
gen.message = msg
generations.append([gen])
response = MagicMock()
response.generations = generations
return response
class TestSubagentTokenCollector:
def test_collects_usage_from_response(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}
collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 1
assert records[0]["caller"] == "subagent:test"
assert records[0]["input_tokens"] == 100
assert records[0]["output_tokens"] == 50
assert records[0]["total_tokens"] == 150
assert "source_run_id" in records[0]
def test_total_tokens_zero_uses_input_plus_output(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 200, "output_tokens": 100, "total_tokens": 0}
collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 1
assert records[0]["total_tokens"] == 300
def test_total_tokens_missing_uses_input_plus_output(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 30, "output_tokens": 20}
collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 1
assert records[0]["total_tokens"] == 50
def test_dedup_same_run_id(self):
collector = SubagentTokenCollector(caller="subagent:test")
run_id = uuid4()
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id)
collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id)
records = collector.snapshot_records()
assert len(records) == 1
def test_no_usage_no_record(self):
collector = SubagentTokenCollector(caller="subagent:test")
collector.on_llm_end(_make_llm_response("Hi", usage=None), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 0
def test_zero_usage_no_record(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 0
def test_skips_empty_generation_and_records_later_usage(self):
collector = SubagentTokenCollector(caller="subagent:test")
response = _make_llm_response_from_usages(
[
None,
{"input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
]
)
collector.on_llm_end(response, run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 1
assert records[0]["total_tokens"] == 30
def test_snapshot_returns_copy(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
snap1 = collector.snapshot_records()
snap2 = collector.snapshot_records()
assert snap1 == snap2
assert snap1 is not snap2
# Mutating snapshot does not affect internal records
snap1.append({"source_run_id": "fake"})
assert len(collector.snapshot_records()) == 1
def test_multiple_calls_accumulate(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4())
collector.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 2
def test_different_run_ids_accumulate_separately(self):
collector = SubagentTokenCollector(caller="subagent:test")
usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
collector.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4())
collector.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 2
assert records[0]["total_tokens"] == 15
assert records[1]["total_tokens"] == 30
def test_message_without_usage_metadata_skipped(self):
"""A response where message has no usage_metadata attribute must be skipped."""
collector = SubagentTokenCollector(caller="subagent:test")
msg = MagicMock(spec=[]) # object without usage_metadata
gen = MagicMock()
gen.message = msg
response = MagicMock()
response.generations = [[gen]]
collector.on_llm_end(response, run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 0
def test_generation_without_message_skipped(self):
"""A generation without a message attribute must be skipped."""
collector = SubagentTokenCollector(caller="subagent:test")
gen = MagicMock(spec=[]) # object without message
response = MagicMock()
response.generations = [[gen]]
collector.on_llm_end(response, run_id=uuid4())
records = collector.snapshot_records()
assert len(records) == 0
+187 -42
View File
@@ -777,22 +777,27 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
def test_cleanup_scheduled_on_cancellation(monkeypatch):
"""Verify cancellation schedules deferred cleanup for the background task."""
"""Verify cancellation handler synchronously cleans up after shielded wait."""
config = _make_subagent_config()
events = []
cleanup_calls = []
scheduled_cleanup_coros = []
poll_count = 0
def get_result(_: str):
nonlocal poll_count
poll_count += 1
if poll_count == 1:
# Main loop polls RUNNING twice, then shielded wait gets COMPLETED
if poll_count <= 2:
return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
return _make_result(FakeSubagentStatus.COMPLETED, result="done")
async def cancel_on_first_sleep(_: float) -> None:
raise asyncio.CancelledError
sleep_count = 0
async def cancel_on_second_sleep(_: float) -> None:
nonlocal sleep_count
sleep_count += 1
if sleep_count == 2:
raise asyncio.CancelledError
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
@@ -804,12 +809,7 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
monkeypatch.setattr(
task_tool_module.asyncio,
"create_task",
lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(),
)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_second_sleep)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
@@ -826,25 +826,48 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
tool_call_id="tc-cancelled-cleanup",
)
assert cleanup_calls == []
assert len(scheduled_cleanup_coros) == 1
asyncio.run(scheduled_cleanup_coros.pop())
# Cleanup happens synchronously within the cancellation handler
assert cleanup_calls == ["tc-cancelled-cleanup"]
def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
"""Verify deferred cleanup gives up after a bounded number of polls."""
"""Verify cancellation handler survives a shielded-wait timeout gracefully.
When the subagent never reaches a terminal state, the shielded wait times
out (or is interrupted), the handler reports whatever usage it can, calls
cleanup (which is a no-op for non-terminal tasks), and re-raises.
"""
config = _make_subagent_config()
config.timeout_seconds = 1
events = []
report_calls = []
cleanup_calls = []
scheduled_cleanup_coros = []
scheduled_cleanups = []
# Always return RUNNING — subagent never finishes
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
async def cancel_on_first_sleep(_: float) -> None:
raise asyncio.CancelledError
def fake_report_subagent_usage(runtime, result):
report_calls.append((runtime, result))
class DummyCleanupTask:
def __init__(self, coro):
self.coro = coro
def add_done_callback(self, callback):
self.callback = callback
def fake_create_task(coro):
scheduled_cleanups.append(coro)
coro.close()
return DummyCleanupTask(coro)
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
task_tool_module,
@@ -852,19 +875,10 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(
task_tool_module,
"get_background_task_result",
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
monkeypatch.setattr(
task_tool_module.asyncio,
"create_task",
lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(),
)
monkeypatch.setattr(task_tool_module.asyncio, "create_task", fake_create_task)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
@@ -881,13 +895,73 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
tool_call_id="tc-cancelled-timeout",
)
async def bounded_sleep(_seconds: float) -> None:
return None
monkeypatch.setattr(task_tool_module.asyncio, "sleep", bounded_sleep)
asyncio.run(scheduled_cleanup_coros.pop())
# Non-terminal tasks cannot be cleaned immediately; a deferred cleanup
# keeps polling after the parent cancellation path exits.
assert cleanup_calls == []
assert len(scheduled_cleanups) == 1
# _report_subagent_usage is called (but skips because result has no records)
assert len(report_calls) == 1
def test_cancellation_wait_uses_subagent_polling_budget(monkeypatch):
"""Cancelled parent waits on the existing subagent polling budget, not a fixed timeout."""
config = _make_subagent_config()
events = []
report_calls = []
cleanup_calls = []
sleep_count = 0
result_polls = 0
terminal_result = _make_result(FakeSubagentStatus.COMPLETED, result="done")
def get_result(_: str):
nonlocal result_polls
result_polls += 1
if result_polls < 5:
return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
return terminal_result
async def cancel_then_continue(_: float) -> None:
nonlocal sleep_count
sleep_count += 1
if sleep_count == 1:
raise asyncio.CancelledError
def fake_report_subagent_usage(runtime, result):
report_calls.append((runtime, result))
async def fail_on_fixed_timeout(awaitable, *, timeout=None):
raise AssertionError(f"cancellation wait should not use fixed timeout={timeout}")
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_then_continue)
monkeypatch.setattr(task_tool_module.asyncio, "wait_for", fail_on_fixed_timeout)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
"cleanup_background_task",
lambda task_id: cleanup_calls.append(task_id),
)
with pytest.raises(asyncio.CancelledError):
_run_task_tool(
runtime=_make_runtime(),
description="执行任务",
prompt="cancel task",
subagent_type="general-purpose",
tool_call_id="tc-cancel-budget",
)
assert report_calls == [(_make_runtime(), terminal_result)]
assert cleanup_calls == ["tc-cancel-budget"]
def test_cancellation_calls_request_cancel(monkeypatch):
@@ -895,7 +969,6 @@ def test_cancellation_calls_request_cancel(monkeypatch):
config = _make_subagent_config()
events = []
cancel_requests = []
scheduled_cleanup_coros = []
async def cancel_on_first_sleep(_: float) -> None:
raise asyncio.CancelledError
@@ -915,11 +988,6 @@ def test_cancellation_calls_request_cancel(monkeypatch):
)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
monkeypatch.setattr(
task_tool_module.asyncio,
"create_task",
lambda coro: (coro.close(), scheduled_cleanup_coros.append(None))[-1] or _DummyScheduledTask(),
)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(
task_tool_module,
@@ -987,3 +1055,80 @@ def test_task_tool_returns_cancelled_message(monkeypatch):
assert output == "Task cancelled by user."
assert any(e.get("type") == "task_cancelled" for e in events)
assert cleanup_calls == ["tc-poll-cancelled"]
def test_cancellation_reports_subagent_usage(monkeypatch):
"""Verify cancellation handler waits (shielded) for subagent terminal state,
then reports the final token usage before re-raising CancelledError.
The report must happen synchronously within the cancellation handler so
the parent worker's finally block sees the updated journal totals.
"""
config = _make_subagent_config()
events = []
report_calls = []
cleanup_calls = []
# Terminal result with token usage collected after cancellation processing
cancel_result = _make_result(FakeSubagentStatus.CANCELLED, error="Cancelled by user")
cancel_result.token_usage_records = [{"source_run_id": "sub-run-1", "caller": "subagent:gp", "input_tokens": 50, "output_tokens": 25, "total_tokens": 75}]
cancel_result.usage_reported = False
poll_count = 0
def get_result(_: str):
nonlocal poll_count
poll_count += 1
# Main loop polls 3 times (RUNNING each time to keep looping)
if poll_count <= 3:
running = _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
running.token_usage_records = []
running.usage_reported = False
return running
# Shielded wait poll gets the terminal result
return cancel_result
sleep_count = 0
async def cancel_on_third_sleep(_: float) -> None:
nonlocal sleep_count
sleep_count += 1
if sleep_count == 3:
raise asyncio.CancelledError
def fake_report_subagent_usage(runtime, result):
report_calls.append((runtime, result))
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_third_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage)
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
monkeypatch.setattr(task_tool_module, "request_cancel_background_task", lambda _: None)
monkeypatch.setattr(
task_tool_module,
"cleanup_background_task",
lambda task_id: cleanup_calls.append(task_id),
)
with pytest.raises(asyncio.CancelledError):
_run_task_tool(
runtime=_make_runtime(),
description="执行任务",
prompt="cancel me",
subagent_type="general-purpose",
tool_call_id="tc-cancel-report",
)
# _report_subagent_usage is called synchronously within the cancellation
# handler (after the shielded wait), before CancelledError is re-raised.
assert len(report_calls) == 1
assert report_calls[0][1] is cancel_result
assert cleanup_calls == ["tc-cancel-report"]