mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
fix: bucket subagent token usage into parent run totals (#2838)
* fix: bucket subagent token usage into RunRow.subagent_tokens Add caller-bucketed token tracking to RunJournal so subagent and middleware LLM calls are written to the correct RunRow columns instead of all falling into lead_agent_tokens (default 0). - RunJournal: accumulate _lead_agent_tokens / _subagent_tokens / _middleware_tokens in on_llm_end, deduped by langchain run_id. Add record_external_llm_usage_records() for external sources (respects track_token_usage flag). Return caller buckets from get_completion_data(). - SubagentTokenCollector: new lightweight callback handler that collects LLM usage within subagent execution. - SubagentExecutor: wire collector into subagent run_config and sync records to SubagentResult on every chunk (timeout/cancel safe). - SubagentResult: add token_usage_records and usage_reported fields. - task_tool: report subagent usage to parent RunJournal on every terminal status (COMPLETED/FAILED/CANCELLED/TIMED_OUT), including the CancelledError path, guarded against double-reporting. No DB migration needed — RunRow columns already exist. * Potential fix for pull request finding Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com> * fix: address token usage review feedback * Address review follow-ups --------- Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -63,6 +63,15 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
self._total_tokens = 0
|
self._total_tokens = 0
|
||||||
self._llm_call_count = 0
|
self._llm_call_count = 0
|
||||||
|
|
||||||
|
# Caller-bucketed token accumulators
|
||||||
|
self._lead_agent_tokens = 0
|
||||||
|
self._subagent_tokens = 0
|
||||||
|
self._middleware_tokens = 0
|
||||||
|
|
||||||
|
# Dedup: LangChain may fire on_llm_end multiple times for the same run_id
|
||||||
|
self._counted_llm_run_ids: set[str] = set()
|
||||||
|
self._counted_external_source_ids: set[str] = set()
|
||||||
|
|
||||||
# Convenience fields
|
# Convenience fields
|
||||||
self._last_ai_msg: str | None = None
|
self._last_ai_msg: str | None = None
|
||||||
self._first_human_msg: str | None = None
|
self._first_human_msg: str | None = None
|
||||||
@@ -214,19 +223,28 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# Token accumulation
|
# Token accumulation (dedup by langchain run_id to avoid double-counting
|
||||||
|
# when the callback fires more than once for the same response)
|
||||||
if self._track_tokens:
|
if self._track_tokens:
|
||||||
input_tk = usage_dict.get("input_tokens", 0) or 0
|
input_tk = usage_dict.get("input_tokens", 0) or 0
|
||||||
output_tk = usage_dict.get("output_tokens", 0) or 0
|
output_tk = usage_dict.get("output_tokens", 0) or 0
|
||||||
total_tk = usage_dict.get("total_tokens", 0) or 0
|
total_tk = usage_dict.get("total_tokens", 0) or 0
|
||||||
if total_tk == 0:
|
if total_tk == 0:
|
||||||
total_tk = input_tk + output_tk
|
total_tk = input_tk + output_tk
|
||||||
if total_tk > 0:
|
if total_tk > 0 and rid not in self._counted_llm_run_ids:
|
||||||
|
self._counted_llm_run_ids.add(rid)
|
||||||
self._total_input_tokens += input_tk
|
self._total_input_tokens += input_tk
|
||||||
self._total_output_tokens += output_tk
|
self._total_output_tokens += output_tk
|
||||||
self._total_tokens += total_tk
|
self._total_tokens += total_tk
|
||||||
self._llm_call_count += 1
|
self._llm_call_count += 1
|
||||||
|
|
||||||
|
if caller.startswith("subagent:"):
|
||||||
|
self._subagent_tokens += total_tk
|
||||||
|
elif caller.startswith("middleware:"):
|
||||||
|
self._middleware_tokens += total_tk
|
||||||
|
else:
|
||||||
|
self._lead_agent_tokens += total_tk
|
||||||
|
|
||||||
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
def on_llm_error(self, error: BaseException, *, run_id: UUID, **kwargs: Any) -> None:
|
||||||
self._llm_start_times.pop(str(run_id), None)
|
self._llm_start_times.pop(str(run_id), None)
|
||||||
self._put(event_type="llm.error", category="trace", content=str(error))
|
self._put(event_type="llm.error", category="trace", content=str(error))
|
||||||
@@ -330,6 +348,49 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
|
|
||||||
# -- Public methods (called by worker) --
|
# -- Public methods (called by worker) --
|
||||||
|
|
||||||
|
def record_external_llm_usage_records(
    self,
    records: list[dict[str, int | str]],
) -> None:
    """Fold token usage reported by an external source (e.g. a subagent) into this journal.

    Each record is a mapping with these keys:
        source_run_id: Unique identifier used to prevent double-counting.
        caller: Caller tag (e.g. "subagent:general-purpose") that selects the bucket.
        input_tokens: Input token count.
        output_tokens: Output token count.
        total_tokens: Total token count; derived from input + output when
            missing or not positive.

    Nothing is recorded when token tracking is disabled. A record is skipped
    when it is not a dict, has no ``source_run_id``, repeats an id that was
    already counted, or reports no tokens. Counts that cannot be read as
    integers are treated as zero so one malformed entry cannot abort the rest
    of the batch.
    """
    if not self._track_tokens or not records:
        return

    def _as_count(value: object) -> int:
        # Counts may be missing, None, or numeric strings; anything that is
        # not readable as an integer contributes nothing.
        try:
            return int(value or 0)
        except (TypeError, ValueError):
            return 0

    for record in records:
        if not isinstance(record, dict):
            continue

        raw_id = record.get("source_run_id")
        # An explicit None must not become the literal (truthy) id "None".
        source_id = "" if raw_id is None else str(raw_id)
        if not source_id or source_id in self._counted_external_source_ids:
            continue

        input_tk = _as_count(record.get("input_tokens"))
        output_tk = _as_count(record.get("output_tokens"))
        total_tk = _as_count(record.get("total_tokens"))
        if total_tk <= 0:
            total_tk = input_tk + output_tk
        if total_tk <= 0:
            continue

        # Mark the id as counted only once it actually contributes tokens, so
        # a later record with the same id and real usage is still accepted.
        self._counted_external_source_ids.add(source_id)
        self._total_input_tokens += input_tk
        self._total_output_tokens += output_tk
        self._total_tokens += total_tk

        # Same prefix-based bucketing as the in-process LLM callback path:
        # anything that is neither subagent nor middleware is the lead agent.
        caller = str(record.get("caller", ""))
        if caller.startswith("subagent:"):
            self._subagent_tokens += total_tk
        elif caller.startswith("middleware:"):
            self._middleware_tokens += total_tk
        else:
            self._lead_agent_tokens += total_tk
|
||||||
|
|
||||||
def set_first_human_message(self, content: str) -> None:
|
def set_first_human_message(self, content: str) -> None:
|
||||||
"""Record the first human message for convenience fields."""
|
"""Record the first human message for convenience fields."""
|
||||||
self._first_human_msg = content[:2000] if content else None
|
self._first_human_msg = content[:2000] if content else None
|
||||||
@@ -376,6 +437,9 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
"total_output_tokens": self._total_output_tokens,
|
"total_output_tokens": self._total_output_tokens,
|
||||||
"total_tokens": self._total_tokens,
|
"total_tokens": self._total_tokens,
|
||||||
"llm_call_count": self._llm_call_count,
|
"llm_call_count": self._llm_call_count,
|
||||||
|
"lead_agent_tokens": self._lead_agent_tokens,
|
||||||
|
"subagent_tokens": self._subagent_tokens,
|
||||||
|
"middleware_tokens": self._middleware_tokens,
|
||||||
"message_count": self._msg_count,
|
"message_count": self._msg_count,
|
||||||
"last_ai_message": self._last_ai_msg,
|
"last_ai_message": self._last_ai_msg,
|
||||||
"first_human_message": self._first_human_msg,
|
"first_human_message": self._first_human_msg,
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from deerflow.models import create_chat_model
|
|||||||
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
from deerflow.skills.tool_policy import filter_tools_by_skill_allowed_tools
|
||||||
from deerflow.skills.types import Skill
|
from deerflow.skills.types import Skill
|
||||||
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
|
from deerflow.subagents.config import SubagentConfig, resolve_subagent_model_name
|
||||||
|
from deerflow.subagents.token_collector import SubagentTokenCollector
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -70,6 +71,8 @@ class SubagentResult:
|
|||||||
started_at: datetime | None = None
|
started_at: datetime | None = None
|
||||||
completed_at: datetime | None = None
|
completed_at: datetime | None = None
|
||||||
ai_messages: list[dict[str, Any]] | None = None
|
ai_messages: list[dict[str, Any]] | None = None
|
||||||
|
token_usage_records: list[dict[str, int | str]] = field(default_factory=list)
|
||||||
|
usage_reported: bool = False
|
||||||
cancel_event: threading.Event = field(default_factory=threading.Event, repr=False)
|
cancel_event: threading.Event = field(default_factory=threading.Event, repr=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
@@ -412,13 +415,20 @@ class SubagentExecutor:
|
|||||||
ai_messages = []
|
ai_messages = []
|
||||||
result.ai_messages = ai_messages
|
result.ai_messages = ai_messages
|
||||||
|
|
||||||
|
collector: SubagentTokenCollector | None = None
|
||||||
try:
|
try:
|
||||||
state, filtered_tools = await self._build_initial_state(task)
|
state, filtered_tools = await self._build_initial_state(task)
|
||||||
agent = self._create_agent(filtered_tools)
|
agent = self._create_agent(filtered_tools)
|
||||||
|
|
||||||
|
# Token collector for subagent LLM calls
|
||||||
|
collector_caller = f"subagent:{self.config.name}"
|
||||||
|
collector = SubagentTokenCollector(caller=collector_caller)
|
||||||
|
|
||||||
# Build config with thread_id for sandbox access and recursion limit
|
# Build config with thread_id for sandbox access and recursion limit
|
||||||
run_config: RunnableConfig = {
|
run_config: RunnableConfig = {
|
||||||
"recursion_limit": self.config.max_turns,
|
"recursion_limit": self.config.max_turns,
|
||||||
|
"callbacks": [collector],
|
||||||
|
"tags": [collector_caller],
|
||||||
}
|
}
|
||||||
context: dict[str, Any] = {}
|
context: dict[str, Any] = {}
|
||||||
if self.thread_id:
|
if self.thread_id:
|
||||||
@@ -441,6 +451,8 @@ class SubagentExecutor:
|
|||||||
result.status = SubagentStatus.CANCELLED
|
result.status = SubagentStatus.CANCELLED
|
||||||
result.error = "Cancelled by user"
|
result.error = "Cancelled by user"
|
||||||
result.completed_at = datetime.now()
|
result.completed_at = datetime.now()
|
||||||
|
if collector is not None:
|
||||||
|
result.token_usage_records = collector.snapshot_records()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
|
async for chunk in agent.astream(state, config=run_config, context=context, stream_mode="values"): # type: ignore[arg-type]
|
||||||
@@ -455,6 +467,7 @@ class SubagentExecutor:
|
|||||||
result.status = SubagentStatus.CANCELLED
|
result.status = SubagentStatus.CANCELLED
|
||||||
result.error = "Cancelled by user"
|
result.error = "Cancelled by user"
|
||||||
result.completed_at = datetime.now()
|
result.completed_at = datetime.now()
|
||||||
|
result.token_usage_records = collector.snapshot_records()
|
||||||
return result
|
return result
|
||||||
|
|
||||||
final_state = chunk
|
final_state = chunk
|
||||||
@@ -481,6 +494,7 @@ class SubagentExecutor:
|
|||||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}")
|
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} captured AI message #{len(ai_messages)}")
|
||||||
|
|
||||||
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution")
|
logger.info(f"[trace={self.trace_id}] Subagent {self.config.name} completed async execution")
|
||||||
|
result.token_usage_records = collector.snapshot_records()
|
||||||
|
|
||||||
if final_state is None:
|
if final_state is None:
|
||||||
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
|
logger.warning(f"[trace={self.trace_id}] Subagent {self.config.name} no final state")
|
||||||
@@ -560,6 +574,8 @@ class SubagentExecutor:
|
|||||||
result.status = SubagentStatus.FAILED
|
result.status = SubagentStatus.FAILED
|
||||||
result.error = str(e)
|
result.error = str(e)
|
||||||
result.completed_at = datetime.now()
|
result.completed_at = datetime.now()
|
||||||
|
if collector is not None:
|
||||||
|
result.token_usage_records = collector.snapshot_records()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,63 @@
|
|||||||
|
"""Callback handler that collects LLM token usage within a subagent.
|
||||||
|
|
||||||
|
Each subagent execution creates its own collector. After the subagent
|
||||||
|
finishes, the collected records are transferred to the parent RunJournal
|
||||||
|
via :meth:`RunJournal.record_external_llm_usage_records`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.callbacks import BaseCallbackHandler
|
||||||
|
|
||||||
|
|
||||||
|
class SubagentTokenCollector(BaseCallbackHandler):
    """Callback handler that gathers per-call LLM token usage for one subagent.

    Every recorded entry is tagged with the ``caller`` string given at
    construction time and with the LangChain run id it came from.
    """

    def __init__(self, caller: str):
        super().__init__()
        self.caller = caller
        # One usage record per counted LLM run.
        self._records: list[dict[str, int | str]] = []
        # Run ids that already produced a record; guards against repeat callbacks.
        self._counted_run_ids: set[str] = set()

    def on_llm_end(
        self,
        response: Any,
        *,
        run_id: Any,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Record usage for ``run_id`` from the first generation that reports tokens.

        A run id is recorded at most once. A callback whose generations carry
        no usage leaves the run id unmarked, so a later callback for the same
        run can still be counted.
        """
        run_key = str(run_id)
        if run_key in self._counted_run_ids:
            return

        # Lazily walk every candidate generation that carries a message.
        candidates = (
            candidate
            for batch in response.generations
            for candidate in batch
            if hasattr(candidate, "message")
        )
        for candidate in candidates:
            metadata = getattr(candidate.message, "usage_metadata", None)
            usage_dict = dict(metadata) if metadata else {}
            prompt_tokens = usage_dict.get("input_tokens", 0) or 0
            completion_tokens = usage_dict.get("output_tokens", 0) or 0
            combined_tokens = usage_dict.get("total_tokens", 0) or 0
            if combined_tokens <= 0:
                # Fall back to the sum when no usable total is reported.
                combined_tokens = prompt_tokens + completion_tokens
            if combined_tokens > 0:
                self._counted_run_ids.add(run_key)
                self._records.append(
                    {
                        "source_run_id": run_key,
                        "caller": self.caller,
                        "input_tokens": prompt_tokens,
                        "output_tokens": completion_tokens,
                        "total_tokens": combined_tokens,
                    }
                )
                return

    def snapshot_records(self) -> list[dict[str, int | str]]:
        """Return a shallow copy of the usage records gathered so far."""
        return self._records.copy()
|
||||||
@@ -27,6 +27,92 @@ if TYPE_CHECKING:
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_subagent_terminal(result: Any) -> bool:
    """Tell whether a background subagent result has finished and may be cleaned up.

    A result counts as terminal when its status is COMPLETED, FAILED,
    CANCELLED or TIMED_OUT, or when it carries a non-None ``completed_at``.
    """
    finished_statuses = {
        SubagentStatus.COMPLETED,
        SubagentStatus.FAILED,
        SubagentStatus.CANCELLED,
        SubagentStatus.TIMED_OUT,
    }
    if result.status in finished_statuses:
        return True
    # Fall back to the completion timestamp when the status is not terminal.
    return getattr(result, "completed_at", None) is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def _await_subagent_terminal(task_id: str, max_polls: int) -> Any | None:
    """Wait for a background subagent to finish, polling at most ``max_polls`` times.

    Returns the terminal result, or None when the task is no longer
    registered or the poll budget runs out. Polls are five seconds apart.
    """
    for _attempt in range(max_polls):
        snapshot = get_background_task_result(task_id)
        # A missing task yields None; a terminal one yields the result itself.
        if snapshot is None or _is_subagent_terminal(snapshot):
            return snapshot
        await asyncio.sleep(5)
    return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _deferred_cleanup_subagent_task(task_id: str, trace_id: str, max_polls: int) -> None:
    """Poll a cancelled subagent and remove its background entry once it is terminal.

    Stops silently when the task is no longer registered, and gives up with a
    warning after ``max_polls`` five-second waits.
    """
    cleanup_poll_count = 0
    while (snapshot := get_background_task_result(task_id)) is not None:
        if _is_subagent_terminal(snapshot):
            cleanup_background_task(task_id)
            break
        if cleanup_poll_count >= max_polls:
            logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls")
            break
        await asyncio.sleep(5)
        cleanup_poll_count += 1
|
||||||
|
|
||||||
|
|
||||||
|
def _log_cleanup_failure(cleanup_task: asyncio.Task[None], *, trace_id: str, task_id: str) -> None:
    """Log the exception raised by a finished cleanup task, if there is one.

    Cancelled tasks are ignored; their ``exception()`` is never queried.
    """
    if not cleanup_task.cancelled() and (exc := cleanup_task.exception()) is not None:
        logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")
|
||||||
|
|
||||||
|
|
||||||
|
# Strong references to in-flight deferred cleanup tasks. The event loop keeps
# only weak references to tasks, so a fire-and-forget task with no other
# referent can be garbage-collected before it finishes.
_DEFERRED_CLEANUP_TASKS: set[asyncio.Task[None]] = set()


def _schedule_deferred_subagent_cleanup(task_id: str, trace_id: str, max_polls: int) -> None:
    """Start a background task that cleans up a cancelled subagent once it is terminal.

    Must be called from a running event loop, since it uses
    ``asyncio.create_task``. The task is held in ``_DEFERRED_CLEANUP_TASKS``
    until it completes, and any exception it raises is logged.

    Args:
        task_id: Identifier of the background subagent task to clean up.
        trace_id: Trace identifier used in log messages.
        max_polls: Poll budget handed to the deferred cleanup coroutine.
    """
    logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
    cleanup_task = asyncio.create_task(_deferred_cleanup_subagent_task(task_id, trace_id, max_polls))
    # Hold the task until it is done, then drop the reference again.
    _DEFERRED_CLEANUP_TASKS.add(cleanup_task)
    cleanup_task.add_done_callback(_DEFERRED_CLEANUP_TASKS.discard)
    cleanup_task.add_done_callback(lambda task: _log_cleanup_failure(task, trace_id=trace_id, task_id=task_id))
|
||||||
|
|
||||||
|
|
||||||
|
def _find_usage_recorder(runtime: Any) -> Any | None:
|
||||||
|
"""Find a callback handler with ``record_external_llm_usage_records`` in the runtime config."""
|
||||||
|
if runtime is None:
|
||||||
|
return None
|
||||||
|
config = getattr(runtime, "config", None)
|
||||||
|
if not isinstance(config, dict):
|
||||||
|
return None
|
||||||
|
callbacks = config.get("callbacks", [])
|
||||||
|
if not callbacks:
|
||||||
|
return None
|
||||||
|
for cb in callbacks:
|
||||||
|
if hasattr(cb, "record_external_llm_usage_records"):
|
||||||
|
return cb
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _report_subagent_usage(runtime: Any, result: Any) -> None:
    """Forward a subagent's collected token usage to the parent usage recorder.

    A result is reported at most once: ``usage_reported`` is set to True
    after a successful hand-off, and results that already have it set (or
    lack the attribute entirely) are skipped. Recorder failures are logged
    as warnings and never propagate.
    """
    already_reported = getattr(result, "usage_reported", True)
    usage_records = [] if already_reported else (getattr(result, "token_usage_records", None) or [])
    if not usage_records:
        return

    recorder = _find_usage_recorder(runtime)
    if recorder is None:
        logger.debug("No usage recorder found in runtime callbacks — subagent token usage not recorded")
        return

    try:
        recorder.record_external_llm_usage_records(usage_records)
        # Only flag the result once the recorder accepted the records.
        result.usage_reported = True
    except Exception:
        logger.warning("Failed to report subagent token usage", exc_info=True)
|
||||||
|
|
||||||
|
|
||||||
def _get_runtime_app_config(runtime: Any) -> "AppConfig | None":
|
def _get_runtime_app_config(runtime: Any) -> "AppConfig | None":
|
||||||
context = getattr(runtime, "context", None)
|
context = getattr(runtime, "context", None)
|
||||||
if isinstance(context, dict):
|
if isinstance(context, dict):
|
||||||
@@ -227,21 +313,25 @@ async def task_tool(
|
|||||||
|
|
||||||
# Check if task completed, failed, or timed out
|
# Check if task completed, failed, or timed out
|
||||||
if result.status == SubagentStatus.COMPLETED:
|
if result.status == SubagentStatus.COMPLETED:
|
||||||
|
_report_subagent_usage(runtime, result)
|
||||||
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
|
||||||
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return f"Task Succeeded. Result: {result.result}"
|
return f"Task Succeeded. Result: {result.result}"
|
||||||
elif result.status == SubagentStatus.FAILED:
|
elif result.status == SubagentStatus.FAILED:
|
||||||
|
_report_subagent_usage(runtime, result)
|
||||||
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
|
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
|
||||||
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return f"Task failed. Error: {result.error}"
|
return f"Task failed. Error: {result.error}"
|
||||||
elif result.status == SubagentStatus.CANCELLED:
|
elif result.status == SubagentStatus.CANCELLED:
|
||||||
|
_report_subagent_usage(runtime, result)
|
||||||
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
|
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
|
||||||
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
|
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
return "Task cancelled by user."
|
return "Task cancelled by user."
|
||||||
elif result.status == SubagentStatus.TIMED_OUT:
|
elif result.status == SubagentStatus.TIMED_OUT:
|
||||||
|
_report_subagent_usage(runtime, result)
|
||||||
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
|
||||||
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
|
||||||
cleanup_background_task(task_id)
|
cleanup_background_task(task_id)
|
||||||
@@ -260,43 +350,28 @@ async def task_tool(
|
|||||||
if poll_count > max_poll_count:
|
if poll_count > max_poll_count:
|
||||||
timeout_minutes = config.timeout_seconds // 60
|
timeout_minutes = config.timeout_seconds // 60
|
||||||
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
|
||||||
|
_report_subagent_usage(runtime, result)
|
||||||
writer({"type": "task_timed_out", "task_id": task_id})
|
writer({"type": "task_timed_out", "task_id": task_id})
|
||||||
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
# Signal the background subagent thread to stop cooperatively.
|
# Signal the background subagent thread to stop cooperatively.
|
||||||
# Without this, the thread (running in ThreadPoolExecutor with its
|
|
||||||
# own event loop via asyncio.run) would continue executing even
|
|
||||||
# after the parent task is cancelled.
|
|
||||||
request_cancel_background_task(task_id)
|
request_cancel_background_task(task_id)
|
||||||
|
|
||||||
async def cleanup_when_done() -> None:
|
# Wait (shielded) for the subagent to reach a terminal state so the
|
||||||
max_cleanup_polls = max_poll_count
|
# final token usage snapshot is reported to the parent RunJournal
|
||||||
cleanup_poll_count = 0
|
# before the parent worker persists get_completion_data().
|
||||||
|
terminal_result = None
|
||||||
|
try:
|
||||||
|
terminal_result = await asyncio.shield(_await_subagent_terminal(task_id, max_poll_count))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
while True:
|
# Report whatever the subagent collected (even if we timed out).
|
||||||
result = get_background_task_result(task_id)
|
final_result = terminal_result or get_background_task_result(task_id)
|
||||||
if result is None:
|
if final_result is not None:
|
||||||
return
|
_report_subagent_usage(runtime, final_result)
|
||||||
|
if final_result is not None and _is_subagent_terminal(final_result):
|
||||||
if result.status in {SubagentStatus.COMPLETED, SubagentStatus.FAILED, SubagentStatus.CANCELLED, SubagentStatus.TIMED_OUT} or getattr(result, "completed_at", None) is not None:
|
cleanup_background_task(task_id)
|
||||||
cleanup_background_task(task_id)
|
else:
|
||||||
return
|
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
|
||||||
|
|
||||||
if cleanup_poll_count > max_cleanup_polls:
|
|
||||||
logger.warning(f"[trace={trace_id}] Deferred cleanup for task {task_id} timed out after {cleanup_poll_count} polls")
|
|
||||||
return
|
|
||||||
|
|
||||||
await asyncio.sleep(5)
|
|
||||||
cleanup_poll_count += 1
|
|
||||||
|
|
||||||
def log_cleanup_failure(cleanup_task: asyncio.Task[None]) -> None:
|
|
||||||
if cleanup_task.cancelled():
|
|
||||||
return
|
|
||||||
|
|
||||||
exc = cleanup_task.exception()
|
|
||||||
if exc is not None:
|
|
||||||
logger.error(f"[trace={trace_id}] Deferred cleanup failed for task {task_id}: {exc}")
|
|
||||||
|
|
||||||
logger.debug(f"[trace={trace_id}] Scheduling deferred cleanup for cancelled task {task_id}")
|
|
||||||
asyncio.create_task(cleanup_when_done()).add_done_callback(log_cleanup_failure)
|
|
||||||
raise
|
raise
|
||||||
|
|||||||
@@ -383,6 +383,244 @@ class TestMiddlewareEvents:
|
|||||||
assert "middleware:guardrail" in event_types
|
assert "middleware:guardrail" in event_types
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallerBucketing:
    """Tests for caller-bucketed token accumulation (lead_agent / subagent / middleware)."""

    def test_lead_agent_bucketing(self, journal_setup):
        """A call tagged ``lead_agent`` is counted only in the lead-agent bucket."""
        journal, _store = journal_setup
        token_usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        response = _make_llm_response("A", usage=token_usage)
        journal.on_llm_end(response, run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        assert journal._lead_agent_tokens == 15
        assert journal._subagent_tokens == 0
        assert journal._middleware_tokens == 0

    def test_subagent_bucketing(self, journal_setup):
        """A call tagged ``subagent:*`` is counted only in the subagent bucket."""
        journal, _store = journal_setup
        token_usage = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
        response = _make_llm_response("B", usage=token_usage)
        journal.on_llm_end(response, run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
        assert journal._subagent_tokens == 30
        assert journal._lead_agent_tokens == 0
        assert journal._middleware_tokens == 0

    def test_middleware_bucketing(self, journal_setup):
        """A call tagged ``middleware:*`` is counted only in the middleware bucket."""
        journal, _store = journal_setup
        token_usage = {"input_tokens": 5, "output_tokens": 2, "total_tokens": 7}
        response = _make_llm_response("C", usage=token_usage)
        journal.on_llm_end(response, run_id=uuid4(), parent_run_id=None, tags=["middleware:summarize"])
        assert journal._middleware_tokens == 7
        assert journal._lead_agent_tokens == 0
        assert journal._subagent_tokens == 0

    def test_mixed_callers_sum_independently(self, journal_setup):
        """Each caller bucket grows on its own while the grand total covers all calls."""
        journal, _store = journal_setup
        token_usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        journal.on_llm_end(_make_llm_response("A", usage=token_usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        journal.on_llm_end(_make_llm_response("B", usage=token_usage), run_id=uuid4(), parent_run_id=None, tags=["subagent:bash"])
        journal.on_llm_end(_make_llm_response("C", usage=token_usage), run_id=uuid4(), parent_run_id=None, tags=["middleware:title"])
        assert journal._lead_agent_tokens == 15
        assert journal._subagent_tokens == 15
        assert journal._middleware_tokens == 15
        assert journal._total_tokens == 45

    def test_get_completion_data_includes_buckets(self, journal_setup):
        """``get_completion_data`` exposes the three caller buckets."""
        journal, _store = journal_setup
        journal._lead_agent_tokens = 100
        journal._subagent_tokens = 200
        journal._middleware_tokens = 50
        completion = journal.get_completion_data()
        assert completion["lead_agent_tokens"] == 100
        assert completion["subagent_tokens"] == 200
        assert completion["middleware_tokens"] == 50

    def test_dedup_same_run_id(self, journal_setup):
        """Same langchain run_id in on_llm_end must not double-count."""
        journal, _store = journal_setup
        shared_run_id = uuid4()
        token_usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        journal.on_llm_end(_make_llm_response("A", usage=token_usage), run_id=shared_run_id, parent_run_id=None, tags=["lead_agent"])
        journal.on_llm_end(_make_llm_response("A", usage=token_usage), run_id=shared_run_id, parent_run_id=None, tags=["lead_agent"])
        assert journal._total_tokens == 15
        assert journal._lead_agent_tokens == 15
        assert journal._llm_call_count == 1

    def test_first_no_usage_second_with_usage(self, journal_setup):
        """First callback with no usage must not block second callback with usage for same run_id."""
        journal, _store = journal_setup
        shared_run_id = uuid4()
        journal.on_llm_end(_make_llm_response("A", usage=None), run_id=shared_run_id, parent_run_id=None, tags=["lead_agent"])
        assert str(shared_run_id) not in journal._counted_llm_run_ids
        # The usage-bearing repeat for the same run_id must still be counted.
        token_usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        journal.on_llm_end(_make_llm_response("A", usage=token_usage), run_id=shared_run_id, parent_run_id=None, tags=["lead_agent"])
        assert journal._total_tokens == 15
        assert journal._lead_agent_tokens == 15

    def test_track_token_usage_false_skips_buckets(self):
        """When token tracking is disabled, caller buckets stay at 0."""
        store = MemoryRunEventStore()
        journal = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
        token_usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        response = _make_llm_response("X", usage=token_usage)
        journal.on_llm_end(response, run_id=uuid4(), parent_run_id=None, tags=["subagent:research"])
        assert journal._subagent_tokens == 0
        assert journal._lead_agent_tokens == 0

    def test_default_no_tags_buckets_as_lead_agent(self, journal_setup):
        """LLM calls without explicit tags default to lead_agent bucket."""
        journal, _store = journal_setup
        token_usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}
        response = _make_llm_response("Hi", usage=token_usage)
        journal.on_llm_end(response, run_id=uuid4(), parent_run_id=None)
        assert journal._lead_agent_tokens == 10
        assert journal._subagent_tokens == 0
        assert journal._middleware_tokens == 0

    def test_unknown_tag_buckets_as_lead_agent(self, journal_setup):
        """Calls with unrecognized tags (not lead_agent/subagent:/middleware:) go to lead_agent."""
        journal, _store = journal_setup
        token_usage = {"input_tokens": 5, "output_tokens": 5, "total_tokens": 10}
        response = _make_llm_response("Hi", usage=token_usage)
        journal.on_llm_end(response, run_id=uuid4(), parent_run_id=None, tags=["some_random_tag"])
        assert journal._lead_agent_tokens == 10
|
||||||
|
|
||||||
|
|
||||||
|
class TestExternalUsageRecords:
    """Tests for record_external_llm_usage_records."""

    def test_records_added_to_subagent_bucket(self, journal_setup):
        """A ``subagent:*`` record updates the subagent bucket and the run totals."""
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "ext-1",
                "caller": "subagent:general-purpose",
                "input_tokens": 100,
                "output_tokens": 50,
                "total_tokens": 150,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._subagent_tokens == 150
        assert j._total_tokens == 150
        assert j._total_input_tokens == 100
        assert j._total_output_tokens == 50

    def test_records_added_to_middleware_bucket(self, journal_setup):
        """A ``middleware:*`` record updates only the middleware bucket."""
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "ext-2",
                "caller": "middleware:summarize",
                "input_tokens": 30,
                "output_tokens": 10,
                "total_tokens": 40,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._middleware_tokens == 40
        assert j._lead_agent_tokens == 0
        assert j._subagent_tokens == 0

    def test_records_added_to_lead_agent_bucket(self, journal_setup):
        """A record whose caller is ``lead_agent`` updates the lead-agent bucket."""
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "ext-3",
                "caller": "lead_agent",
                "input_tokens": 10,
                "output_tokens": 5,
                "total_tokens": 15,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._lead_agent_tokens == 15

    def test_dedup_same_source_run_id(self, journal_setup):
        """Same source_run_id must not be double-counted."""
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "dup-1",
                "caller": "subagent:research",
                "input_tokens": 50,
                "output_tokens": 25,
                "total_tokens": 75,
            }
        ]
        # Submitting the identical record list twice must count it once.
        j.record_external_llm_usage_records(records)
        j.record_external_llm_usage_records(records)
        assert j._subagent_tokens == 75
        assert j._total_tokens == 75

    def test_total_tokens_missing_computed_from_input_output(self, journal_setup):
        """A record reporting ``total_tokens`` as 0 is counted as input + output."""
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "ext-4",
                "caller": "subagent:bash",
                "input_tokens": 200,
                "output_tokens": 100,
                "total_tokens": 0,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._subagent_tokens == 300
        assert j._total_tokens == 300

    def test_total_tokens_zero_no_count(self, journal_setup):
        """Records with zero total and zero input+output must not be counted."""
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "ext-5",
                "caller": "subagent:research",
                "input_tokens": 0,
                "output_tokens": 0,
                "total_tokens": 0,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._total_tokens == 0
        assert j._subagent_tokens == 0

    def test_empty_source_run_id_skipped(self, journal_setup):
        """A record with an empty ``source_run_id`` leaves the run total at 0."""
        j, _ = journal_setup
        records = [
            {
                "source_run_id": "",
                "caller": "subagent:research",
                "input_tokens": 50,
                "output_tokens": 25,
                "total_tokens": 75,
            }
        ]
        j.record_external_llm_usage_records(records)
        assert j._total_tokens == 0

    def test_multiple_records_in_single_call(self, journal_setup):
        """Two records with distinct run ids passed in one call are both summed."""
        j, _ = journal_setup
        records = [
            {"source_run_id": "r1", "caller": "subagent:gp", "input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
            {"source_run_id": "r2", "caller": "subagent:bash", "input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
        ]
        j.record_external_llm_usage_records(records)
        assert j._subagent_tokens == 45
        assert j._total_tokens == 45

    def test_external_records_coexist_with_inline_callbacks(self, journal_setup):
        """External records and inline on_llm_end must not interfere."""
        j, _ = journal_setup
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        j.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
        j.record_external_llm_usage_records([{"source_run_id": "ext-6", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}])
        assert j._lead_agent_tokens == 15
        assert j._subagent_tokens == 150
        # The run total is the sum of both sources: 15 + 150.
        assert j._total_tokens == 165

    def test_track_token_usage_false_skips_external_records(self):
        """When token tracking is disabled, external records must not accumulate."""
        store = MemoryRunEventStore()
        j = RunJournal("r1", "t1", store, track_token_usage=False, flush_threshold=100)
        j.record_external_llm_usage_records([{"source_run_id": "ext-7", "caller": "subagent:gp", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150}])
        assert j._total_tokens == 0
        assert j._subagent_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
class TestChatModelStartHumanMessage:
|
class TestChatModelStartHumanMessage:
|
||||||
"""Tests for on_chat_model_start extracting the first human message."""
|
"""Tests for on_chat_model_start extracting the first human message."""
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,161 @@
|
|||||||
|
"""Tests for SubagentTokenCollector callback handler."""
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from deerflow.subagents.token_collector import SubagentTokenCollector
|
||||||
|
|
||||||
|
|
||||||
|
def _make_llm_response(content="Hello", usage=None):
|
||||||
|
"""Create a mock LLM response with a message."""
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.content = content
|
||||||
|
msg.usage_metadata = usage
|
||||||
|
|
||||||
|
gen = MagicMock()
|
||||||
|
gen.message = msg
|
||||||
|
|
||||||
|
response = MagicMock()
|
||||||
|
response.generations = [[gen]]
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def _make_llm_response_from_usages(usages):
|
||||||
|
"""Create a mock LLM response with one generation per usage entry."""
|
||||||
|
generations = []
|
||||||
|
for usage in usages:
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.content = "chunk"
|
||||||
|
msg.usage_metadata = usage
|
||||||
|
|
||||||
|
gen = MagicMock()
|
||||||
|
gen.message = msg
|
||||||
|
generations.append([gen])
|
||||||
|
|
||||||
|
response = MagicMock()
|
||||||
|
response.generations = generations
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
class TestSubagentTokenCollector:
    """Tests for ``SubagentTokenCollector.on_llm_end`` and ``snapshot_records``."""

    def test_collects_usage_from_response(self):
        """One response with usage yields one record carrying caller and token counts."""
        collector = SubagentTokenCollector(caller="subagent:test")
        usage = {"input_tokens": 100, "output_tokens": 50, "total_tokens": 150}
        collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 1
        assert records[0]["caller"] == "subagent:test"
        assert records[0]["input_tokens"] == 100
        assert records[0]["output_tokens"] == 50
        assert records[0]["total_tokens"] == 150
        assert "source_run_id" in records[0]

    def test_total_tokens_zero_uses_input_plus_output(self):
        """A ``total_tokens`` of 0 is replaced by input + output in the record."""
        collector = SubagentTokenCollector(caller="subagent:test")
        usage = {"input_tokens": 200, "output_tokens": 100, "total_tokens": 0}
        collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 1
        assert records[0]["total_tokens"] == 300

    def test_total_tokens_missing_uses_input_plus_output(self):
        """A usage dict without ``total_tokens`` is recorded with input + output."""
        collector = SubagentTokenCollector(caller="subagent:test")
        usage = {"input_tokens": 30, "output_tokens": 20}
        collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 1
        assert records[0]["total_tokens"] == 50

    def test_dedup_same_run_id(self):
        """Two callbacks sharing one ``run_id`` produce a single record."""
        collector = SubagentTokenCollector(caller="subagent:test")
        run_id = uuid4()
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id)
        collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=run_id)
        records = collector.snapshot_records()
        assert len(records) == 1

    def test_no_usage_no_record(self):
        """A response whose ``usage_metadata`` is ``None`` produces no record."""
        collector = SubagentTokenCollector(caller="subagent:test")
        collector.on_llm_end(_make_llm_response("Hi", usage=None), run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 0

    def test_zero_usage_no_record(self):
        """A usage dict whose counts are all 0 produces no record."""
        collector = SubagentTokenCollector(caller="subagent:test")
        usage = {"input_tokens": 0, "output_tokens": 0, "total_tokens": 0}
        collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 0

    def test_skips_empty_generation_and_records_later_usage(self):
        """A leading generation without usage does not hide usage on a later one."""
        collector = SubagentTokenCollector(caller="subagent:test")
        response = _make_llm_response_from_usages(
            [
                None,
                {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30},
            ]
        )

        collector.on_llm_end(response, run_id=uuid4())

        records = collector.snapshot_records()
        assert len(records) == 1
        assert records[0]["total_tokens"] == 30

    def test_snapshot_returns_copy(self):
        """Each snapshot is an equal but distinct list, detached from the collector."""
        collector = SubagentTokenCollector(caller="subagent:test")
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        collector.on_llm_end(_make_llm_response("Hi", usage=usage), run_id=uuid4())
        snap1 = collector.snapshot_records()
        snap2 = collector.snapshot_records()
        assert snap1 == snap2
        assert snap1 is not snap2
        # Mutating snapshot does not affect internal records
        snap1.append({"source_run_id": "fake"})
        assert len(collector.snapshot_records()) == 1

    def test_multiple_calls_accumulate(self):
        """Callbacks with different run ids each add a record."""
        collector = SubagentTokenCollector(caller="subagent:test")
        usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        collector.on_llm_end(_make_llm_response("A", usage=usage), run_id=uuid4())
        collector.on_llm_end(_make_llm_response("B", usage=usage), run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 2

    def test_different_run_ids_accumulate_separately(self):
        """Records keep their own totals and appear in callback order."""
        collector = SubagentTokenCollector(caller="subagent:test")
        usage1 = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
        usage2 = {"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}
        collector.on_llm_end(_make_llm_response("A", usage=usage1), run_id=uuid4())
        collector.on_llm_end(_make_llm_response("B", usage=usage2), run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 2
        assert records[0]["total_tokens"] == 15
        assert records[1]["total_tokens"] == 30

    def test_message_without_usage_metadata_skipped(self):
        """A response where message has no usage_metadata attribute must be skipped."""
        collector = SubagentTokenCollector(caller="subagent:test")

        msg = MagicMock(spec=[])  # object without usage_metadata
        gen = MagicMock()
        gen.message = msg
        response = MagicMock()
        response.generations = [[gen]]

        collector.on_llm_end(response, run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 0

    def test_generation_without_message_skipped(self):
        """A generation without a message attribute must be skipped."""
        collector = SubagentTokenCollector(caller="subagent:test")

        gen = MagicMock(spec=[])  # object without message
        response = MagicMock()
        response.generations = [[gen]]

        collector.on_llm_end(response, run_id=uuid4())
        records = collector.snapshot_records()
        assert len(records) == 0
|
||||||
@@ -777,22 +777,27 @@ def test_cleanup_not_called_on_polling_safety_timeout(monkeypatch):
|
|||||||
|
|
||||||
|
|
||||||
def test_cleanup_scheduled_on_cancellation(monkeypatch):
|
def test_cleanup_scheduled_on_cancellation(monkeypatch):
|
||||||
"""Verify cancellation schedules deferred cleanup for the background task."""
|
"""Verify cancellation handler synchronously cleans up after shielded wait."""
|
||||||
config = _make_subagent_config()
|
config = _make_subagent_config()
|
||||||
events = []
|
events = []
|
||||||
cleanup_calls = []
|
cleanup_calls = []
|
||||||
scheduled_cleanup_coros = []
|
|
||||||
poll_count = 0
|
poll_count = 0
|
||||||
|
|
||||||
def get_result(_: str):
|
def get_result(_: str):
|
||||||
nonlocal poll_count
|
nonlocal poll_count
|
||||||
poll_count += 1
|
poll_count += 1
|
||||||
if poll_count == 1:
|
# Main loop polls RUNNING twice, then shielded wait gets COMPLETED
|
||||||
|
if poll_count <= 2:
|
||||||
return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
|
return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
|
||||||
return _make_result(FakeSubagentStatus.COMPLETED, result="done")
|
return _make_result(FakeSubagentStatus.COMPLETED, result="done")
|
||||||
|
|
||||||
async def cancel_on_first_sleep(_: float) -> None:
|
sleep_count = 0
|
||||||
raise asyncio.CancelledError
|
|
||||||
|
async def cancel_on_second_sleep(_: float) -> None:
|
||||||
|
nonlocal sleep_count
|
||||||
|
sleep_count += 1
|
||||||
|
if sleep_count == 2:
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
@@ -804,12 +809,7 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
|
|||||||
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
|
monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_second_sleep)
|
||||||
monkeypatch.setattr(
|
|
||||||
task_tool_module.asyncio,
|
|
||||||
"create_task",
|
|
||||||
lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
task_tool_module,
|
task_tool_module,
|
||||||
@@ -826,25 +826,48 @@ def test_cleanup_scheduled_on_cancellation(monkeypatch):
|
|||||||
tool_call_id="tc-cancelled-cleanup",
|
tool_call_id="tc-cancelled-cleanup",
|
||||||
)
|
)
|
||||||
|
|
||||||
assert cleanup_calls == []
|
# Cleanup happens synchronously within the cancellation handler
|
||||||
assert len(scheduled_cleanup_coros) == 1
|
|
||||||
|
|
||||||
asyncio.run(scheduled_cleanup_coros.pop())
|
|
||||||
|
|
||||||
assert cleanup_calls == ["tc-cancelled-cleanup"]
|
assert cleanup_calls == ["tc-cancelled-cleanup"]
|
||||||
|
|
||||||
|
|
||||||
def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
|
def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
|
||||||
"""Verify deferred cleanup gives up after a bounded number of polls."""
|
"""Verify cancellation handler survives a shielded-wait timeout gracefully.
|
||||||
|
|
||||||
|
When the subagent never reaches a terminal state, the shielded wait times
|
||||||
|
out (or is interrupted), the handler reports whatever usage it can, calls
|
||||||
|
cleanup (which is a no-op for non-terminal tasks), and re-raises.
|
||||||
|
"""
|
||||||
config = _make_subagent_config()
|
config = _make_subagent_config()
|
||||||
config.timeout_seconds = 1
|
|
||||||
events = []
|
events = []
|
||||||
|
report_calls = []
|
||||||
cleanup_calls = []
|
cleanup_calls = []
|
||||||
scheduled_cleanup_coros = []
|
scheduled_cleanups = []
|
||||||
|
|
||||||
|
# Always return RUNNING — subagent never finishes
|
||||||
|
monkeypatch.setattr(
|
||||||
|
task_tool_module,
|
||||||
|
"get_background_task_result",
|
||||||
|
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
||||||
|
)
|
||||||
|
|
||||||
async def cancel_on_first_sleep(_: float) -> None:
|
async def cancel_on_first_sleep(_: float) -> None:
|
||||||
raise asyncio.CancelledError
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
def fake_report_subagent_usage(runtime, result):
|
||||||
|
report_calls.append((runtime, result))
|
||||||
|
|
||||||
|
class DummyCleanupTask:
|
||||||
|
def __init__(self, coro):
|
||||||
|
self.coro = coro
|
||||||
|
|
||||||
|
def add_done_callback(self, callback):
|
||||||
|
self.callback = callback
|
||||||
|
|
||||||
|
def fake_create_task(coro):
|
||||||
|
scheduled_cleanups.append(coro)
|
||||||
|
coro.close()
|
||||||
|
return DummyCleanupTask(coro)
|
||||||
|
|
||||||
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
task_tool_module,
|
task_tool_module,
|
||||||
@@ -852,19 +875,10 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
|
|||||||
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
|
||||||
)
|
)
|
||||||
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
|
||||||
|
|
||||||
monkeypatch.setattr(
|
|
||||||
task_tool_module,
|
|
||||||
"get_background_task_result",
|
|
||||||
lambda _: _make_result(FakeSubagentStatus.RUNNING, ai_messages=[]),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(task_tool_module.asyncio, "create_task", fake_create_task)
|
||||||
task_tool_module.asyncio,
|
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage)
|
||||||
"create_task",
|
|
||||||
lambda coro: scheduled_cleanup_coros.append(coro) or _DummyScheduledTask(),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
task_tool_module,
|
task_tool_module,
|
||||||
@@ -881,13 +895,73 @@ def test_cancelled_cleanup_stops_after_timeout(monkeypatch):
|
|||||||
tool_call_id="tc-cancelled-timeout",
|
tool_call_id="tc-cancelled-timeout",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def bounded_sleep(_seconds: float) -> None:
|
# Non-terminal tasks cannot be cleaned immediately; a deferred cleanup
|
||||||
return None
|
# keeps polling after the parent cancellation path exits.
|
||||||
|
|
||||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", bounded_sleep)
|
|
||||||
asyncio.run(scheduled_cleanup_coros.pop())
|
|
||||||
|
|
||||||
assert cleanup_calls == []
|
assert cleanup_calls == []
|
||||||
|
assert len(scheduled_cleanups) == 1
|
||||||
|
# _report_subagent_usage is called (but skips because result has no records)
|
||||||
|
assert len(report_calls) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancellation_wait_uses_subagent_polling_budget(monkeypatch):
    """Cancelled parent waits on the existing subagent polling budget, not a fixed timeout.

    The patched ``asyncio.sleep`` raises ``CancelledError`` on its first call
    only, and the patched result getter reports RUNNING for the first four
    polls before returning a COMPLETED result. ``asyncio.wait_for`` is
    replaced with a stub that fails the test if it is ever called.
    """
    config = _make_subagent_config()
    events = []
    report_calls = []
    cleanup_calls = []
    sleep_count = 0
    result_polls = 0
    terminal_result = _make_result(FakeSubagentStatus.COMPLETED, result="done")

    def get_result(_: str):
        nonlocal result_polls
        result_polls += 1
        if result_polls < 5:
            return _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
        return terminal_result

    async def cancel_then_continue(_: float) -> None:
        # Cancel exactly once; every later sleep returns immediately.
        nonlocal sleep_count
        sleep_count += 1
        if sleep_count == 1:
            raise asyncio.CancelledError

    def fake_report_subagent_usage(runtime, result):
        report_calls.append((runtime, result))

    async def fail_on_fixed_timeout(awaitable, *, timeout=None):
        raise AssertionError(f"cancellation wait should not use fixed timeout={timeout}")

    monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
    monkeypatch.setattr(
        task_tool_module,
        "SubagentExecutor",
        type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
    )
    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
    monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
    monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
    monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_then_continue)
    monkeypatch.setattr(task_tool_module.asyncio, "wait_for", fail_on_fixed_timeout)
    monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage)
    monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
    monkeypatch.setattr(
        task_tool_module,
        "cleanup_background_task",
        lambda task_id: cleanup_calls.append(task_id),
    )

    with pytest.raises(asyncio.CancelledError):
        _run_task_tool(
            runtime=_make_runtime(),
            description="执行任务",
            prompt="cancel task",
            subagent_type="general-purpose",
            tool_call_id="tc-cancel-budget",
        )

    # Usage is reported once, with the terminal result, and the task is cleaned up.
    assert report_calls == [(_make_runtime(), terminal_result)]
    assert cleanup_calls == ["tc-cancel-budget"]
|
||||||
|
|
||||||
|
|
||||||
def test_cancellation_calls_request_cancel(monkeypatch):
|
def test_cancellation_calls_request_cancel(monkeypatch):
|
||||||
@@ -895,7 +969,6 @@ def test_cancellation_calls_request_cancel(monkeypatch):
|
|||||||
config = _make_subagent_config()
|
config = _make_subagent_config()
|
||||||
events = []
|
events = []
|
||||||
cancel_requests = []
|
cancel_requests = []
|
||||||
scheduled_cleanup_coros = []
|
|
||||||
|
|
||||||
async def cancel_on_first_sleep(_: float) -> None:
|
async def cancel_on_first_sleep(_: float) -> None:
|
||||||
raise asyncio.CancelledError
|
raise asyncio.CancelledError
|
||||||
@@ -915,11 +988,6 @@ def test_cancellation_calls_request_cancel(monkeypatch):
|
|||||||
)
|
)
|
||||||
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
|
||||||
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_first_sleep)
|
||||||
monkeypatch.setattr(
|
|
||||||
task_tool_module.asyncio,
|
|
||||||
"create_task",
|
|
||||||
lambda coro: (coro.close(), scheduled_cleanup_coros.append(None))[-1] or _DummyScheduledTask(),
|
|
||||||
)
|
|
||||||
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
task_tool_module,
|
task_tool_module,
|
||||||
@@ -987,3 +1055,80 @@ def test_task_tool_returns_cancelled_message(monkeypatch):
|
|||||||
assert output == "Task cancelled by user."
|
assert output == "Task cancelled by user."
|
||||||
assert any(e.get("type") == "task_cancelled" for e in events)
|
assert any(e.get("type") == "task_cancelled" for e in events)
|
||||||
assert cleanup_calls == ["tc-poll-cancelled"]
|
assert cleanup_calls == ["tc-poll-cancelled"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancellation_reports_subagent_usage(monkeypatch):
    """Verify cancellation handler waits (shielded) for subagent terminal state,
    then reports the final token usage before re-raising CancelledError.

    The report must happen synchronously within the cancellation handler so
    the parent worker's finally block sees the updated journal totals.
    """
    config = _make_subagent_config()
    events = []
    report_calls = []
    cleanup_calls = []

    # Terminal result with token usage collected after cancellation processing
    cancel_result = _make_result(FakeSubagentStatus.CANCELLED, error="Cancelled by user")
    cancel_result.token_usage_records = [{"source_run_id": "sub-run-1", "caller": "subagent:gp", "input_tokens": 50, "output_tokens": 25, "total_tokens": 75}]
    cancel_result.usage_reported = False

    poll_count = 0

    def get_result(_: str):
        nonlocal poll_count
        poll_count += 1
        # Main loop polls 3 times (RUNNING each time to keep looping)
        if poll_count <= 3:
            running = _make_result(FakeSubagentStatus.RUNNING, ai_messages=[])
            running.token_usage_records = []
            running.usage_reported = False
            return running
        # Shielded wait poll gets the terminal result
        return cancel_result

    sleep_count = 0

    async def cancel_on_third_sleep(_: float) -> None:
        # Only the third sleep raises; all other sleeps return immediately.
        nonlocal sleep_count
        sleep_count += 1
        if sleep_count == 3:
            raise asyncio.CancelledError

    def fake_report_subagent_usage(runtime, result):
        report_calls.append((runtime, result))

    monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
    monkeypatch.setattr(
        task_tool_module,
        "SubagentExecutor",
        type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
    )
    monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
    monkeypatch.setattr(task_tool_module, "get_background_task_result", get_result)
    monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
    monkeypatch.setattr(task_tool_module.asyncio, "sleep", cancel_on_third_sleep)
    monkeypatch.setattr(task_tool_module, "_report_subagent_usage", fake_report_subagent_usage)
    monkeypatch.setattr("deerflow.tools.get_available_tools", lambda **kwargs: [])
    monkeypatch.setattr(task_tool_module, "request_cancel_background_task", lambda _: None)
    monkeypatch.setattr(
        task_tool_module,
        "cleanup_background_task",
        lambda task_id: cleanup_calls.append(task_id),
    )

    with pytest.raises(asyncio.CancelledError):
        _run_task_tool(
            runtime=_make_runtime(),
            description="执行任务",
            prompt="cancel me",
            subagent_type="general-purpose",
            tool_call_id="tc-cancel-report",
        )

    # _report_subagent_usage is called synchronously within the cancellation
    # handler (after the shielded wait), before CancelledError is re-raised.
    assert len(report_calls) == 1
    assert report_calls[0][1] is cancel_result
    assert cleanup_calls == ["tc-cancel-report"]
|
||||||
|
|||||||
@@ -20,7 +20,11 @@ test("fetchThreadTokenUsage uses shared auth fetch without JSON GET headers", as
|
|||||||
total_tokens: 7,
|
total_tokens: 7,
|
||||||
total_runs: 1,
|
total_runs: 1,
|
||||||
by_model: { unknown: { tokens: 7, runs: 1 } },
|
by_model: { unknown: { tokens: 7, runs: 1 } },
|
||||||
by_caller: {},
|
by_caller: {
|
||||||
|
lead_agent: 0,
|
||||||
|
subagent: 0,
|
||||||
|
middleware: 0,
|
||||||
|
},
|
||||||
}),
|
}),
|
||||||
});
|
});
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user