diff --git a/backend/app/gateway/routers/thread_runs.py b/backend/app/gateway/routers/thread_runs.py index 294fa9799..a542593b2 100644 --- a/backend/app/gateway/routers/thread_runs.py +++ b/backend/app/gateway/routers/thread_runs.py @@ -66,6 +66,14 @@ class RunResponse(BaseModel): multitask_strategy: str = "reject" created_at: str = "" updated_at: str = "" + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_tokens: int = 0 + llm_call_count: int = 0 + lead_agent_tokens: int = 0 + subagent_tokens: int = 0 + middleware_tokens: int = 0 + message_count: int = 0 class ThreadTokenUsageModelBreakdown(BaseModel): @@ -111,6 +119,14 @@ def _record_to_response(record: RunRecord) -> RunResponse: multitask_strategy=record.multitask_strategy, created_at=record.created_at, updated_at=record.updated_at, + total_input_tokens=record.total_input_tokens, + total_output_tokens=record.total_output_tokens, + total_tokens=record.total_tokens, + llm_call_count=record.llm_call_count, + lead_agent_tokens=record.lead_agent_tokens, + subagent_tokens=record.subagent_tokens, + middleware_tokens=record.middleware_tokens, + message_count=record.message_count, ) @@ -402,8 +418,15 @@ async def list_run_events( @router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse) @require_permission("threads", "read", owner_check=True) -async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse: +async def thread_token_usage( + thread_id: str, + request: Request, + include_active: bool = Query(default=False, description="Include running run progress snapshots"), +) -> ThreadTokenUsageResponse: """Thread-level token usage aggregation.""" run_store = get_run_store(request) - agg = await run_store.aggregate_tokens_by_thread(thread_id) + if include_active: + agg = await run_store.aggregate_tokens_by_thread(thread_id, include_active=True) + else: + agg = await run_store.aggregate_tokens_by_thread(thread_id) return ThreadTokenUsageResponse(thread_id=thread_id, **agg) diff --git a/backend/packages/harness/deerflow/persistence/run/sql.py b/backend/packages/harness/deerflow/persistence/run/sql.py index 5679cc68f..1be9fb159 100644 --- a/backend/packages/harness/deerflow/persistence/run/sql.py +++ b/backend/packages/harness/deerflow/persistence/run/sql.py @@ -227,9 +227,48 @@ class RunRepository(RunStore): await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.commit() - async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: + async def update_run_progress( + self, + run_id: str, + *, + total_input_tokens: int | None = None, + total_output_tokens: int | None = None, + total_tokens: int | None = None, + llm_call_count: int | None = None, + lead_agent_tokens: int | None = None, + subagent_tokens: int | None = None, + middleware_tokens: int | None = None, + message_count: int | None = None, + last_ai_message: str | None = None, + first_human_message: str | None = None, + ) -> None: + """Update token usage + convenience fields while a run is still active.""" + values: dict[str, Any] = {"updated_at": datetime.now(UTC)} + optional_counters = { + "total_input_tokens": total_input_tokens, + "total_output_tokens": total_output_tokens, + "total_tokens": total_tokens, + "llm_call_count": llm_call_count, + "lead_agent_tokens": lead_agent_tokens, + "subagent_tokens": subagent_tokens, + "middleware_tokens": middleware_tokens, + "message_count": message_count, + } + for key, value in optional_counters.items(): + if value is not None: + values[key] = value + if last_ai_message is not None: + values["last_ai_message"] = last_ai_message[:2000] + if first_human_message is not None: + values["first_human_message"] = first_human_message[:2000] + async with self._sf() as session: + await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values)) + await session.commit() + + async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: """Aggregate token usage via a single SQL GROUP BY query.""" - _completed = RunRow.status.in_(("success", "error")) + statuses = ("success", "error", "running") if include_active else ("success", "error") + _completed = RunRow.status.in_(statuses) _thread = RunRow.thread_id == thread_id model_name = func.coalesce(RunRow.model_name, "unknown") diff --git a/backend/packages/harness/deerflow/runtime/journal.py b/backend/packages/harness/deerflow/runtime/journal.py index 8a9382e23..a12ebd98b 100644 --- a/backend/packages/harness/deerflow/runtime/journal.py +++ b/backend/packages/harness/deerflow/runtime/journal.py @@ -20,7 +20,7 @@ from __future__ import annotations import asyncio import logging import time -from collections.abc import Mapping +from collections.abc import Awaitable, Callable, Mapping from datetime import UTC, datetime from typing import TYPE_CHECKING, Any, cast from uuid import UUID @@ -46,6 +46,8 @@ class RunJournal(BaseCallbackHandler): *, track_token_usage: bool = True, flush_threshold: int = 20, + progress_reporter: Callable[[dict], Awaitable[None]] | None = None, + progress_flush_interval: float = 5.0, ): super().__init__() self.run_id = run_id @@ -53,10 +55,16 @@ class RunJournal(BaseCallbackHandler): self._store = event_store self._track_tokens = track_token_usage self._flush_threshold = flush_threshold + self._progress_reporter = progress_reporter + self._progress_flush_interval = progress_flush_interval # Write buffer self._buffer: list[dict] = [] self._pending_flush_tasks: set[asyncio.Task[None]] = set() + self._pending_progress_task: asyncio.Task[None] | None = None + self._pending_progress_delayed = False + self._progress_dirty = False + self._last_progress_flush = 0.0 # Token accumulators self._total_input_tokens = 0 @@ -294,6 +302,8 @@ class RunJournal(BaseCallbackHandler): else: self._lead_agent_tokens += total_tk + self._schedule_progress_flush() + if messages: self._counted_message_llm_run_ids.add(str(run_id)) @@ -445,6 +455,8 @@ class RunJournal(BaseCallbackHandler): else: self._lead_agent_tokens += total_tk + self._schedule_progress_flush() + def set_first_human_message(self, content: str) -> None: """Record the first human message for convenience fields.""" self._first_human_msg = content[:2000] if content else None @@ -474,6 +486,14 @@ class RunJournal(BaseCallbackHandler): """Force flush remaining buffer. Called in worker's finally block.""" if self._pending_flush_tasks: await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True) + while self._pending_progress_task is not None and not self._pending_progress_task.done(): + if self._pending_progress_delayed: + self._pending_progress_task.cancel() + await asyncio.gather(self._pending_progress_task, return_exceptions=True) + self._progress_dirty = False + self._pending_progress_delayed = False + break + await asyncio.gather(self._pending_progress_task, return_exceptions=True) while self._buffer: batch = self._buffer[: self._flush_threshold] @@ -484,6 +504,57 @@ class RunJournal(BaseCallbackHandler): self._buffer = batch + self._buffer raise + def _schedule_progress_flush(self) -> None: + """Best-effort throttled progress snapshot for active run visibility.""" + if self._progress_reporter is None: + return + now = time.monotonic() + elapsed = now - self._last_progress_flush + if elapsed < self._progress_flush_interval: + self._progress_dirty = True + self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed) + return + if self._pending_progress_task is not None and not self._pending_progress_task.done(): + self._progress_dirty = True + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + self._progress_dirty = False + self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data())) + + def _schedule_delayed_progress_flush(self, delay: float) -> None: + if self._pending_progress_task is not None and not self._pending_progress_task.done(): + return + try: + loop = asyncio.get_running_loop() + except RuntimeError: + return + delay = max(0.0, delay) + self._pending_progress_delayed = delay > 0 + self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay)) + + async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None: + if self._progress_reporter is None: + return + if delay > 0: + self._pending_progress_delayed = True + await asyncio.sleep(delay) + self._pending_progress_delayed = False + dirty_before_write = self._progress_dirty + self._progress_dirty = False + snapshot_to_write = snapshot or self.get_completion_data() + try: + await self._progress_reporter(snapshot_to_write) + self._last_progress_flush = time.monotonic() + except Exception: + logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True) + if dirty_before_write or self._progress_dirty: + self._progress_dirty = False + self._pending_progress_task = None + self._schedule_delayed_progress_flush(self._progress_flush_interval) + def get_completion_data(self) -> dict: """Return accumulated token and message data for run completion.""" return { diff --git a/backend/packages/harness/deerflow/runtime/runs/manager.py b/backend/packages/harness/deerflow/runtime/runs/manager.py index ea78f89c9..5387689dc 100644 --- a/backend/packages/harness/deerflow/runtime/runs/manager.py +++ b/backend/packages/harness/deerflow/runtime/runs/manager.py @@ -38,6 +38,16 @@ class RunRecord: error: str | None = None model_name: str | None = None store_only: bool = False + total_input_tokens: int = 0 + total_output_tokens: int = 0 + total_tokens: int = 0 + llm_call_count: int = 0 + lead_agent_tokens: int = 0 + subagent_tokens: int = 0 + middleware_tokens: int = 0 + message_count: int = 0 + last_ai_message: str | None = None + first_human_message: str | None = None class RunManager: @@ -102,16 +112,53 @@ class RunManager: error=row.get("error"), model_name=row.get("model_name"), store_only=True, + total_input_tokens=row.get("total_input_tokens") or 0, + total_output_tokens=row.get("total_output_tokens") or 0, + total_tokens=row.get("total_tokens") or 0, + llm_call_count=row.get("llm_call_count") or 0, + lead_agent_tokens=row.get("lead_agent_tokens") or 0, + subagent_tokens=row.get("subagent_tokens") or 0, + middleware_tokens=row.get("middleware_tokens") or 0, + message_count=row.get("message_count") or 0, + last_ai_message=row.get("last_ai_message"), + first_human_message=row.get("first_human_message"), ) async def update_run_completion(self, run_id: str, **kwargs) -> None: """Persist token usage and completion data to the backing store.""" + async with self._lock: + record = self._runs.get(run_id) + if record is not None: + for key, value in kwargs.items(): + if key == "status": + continue + if hasattr(record, key) and value is not None: + setattr(record, key, value) + record.updated_at = _now_iso() if self._store is not None: try: await self._store.update_run_completion(run_id, **kwargs) except Exception: logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) + async def update_run_progress(self, run_id: str, **kwargs) -> None: + """Persist a running token/message snapshot without changing status.""" + should_persist = True + async with self._lock: + record = self._runs.get(run_id) + if record is not None: + should_persist = record.status == RunStatus.running + if record is not None and should_persist: + for key, value in kwargs.items(): + if hasattr(record, key) and value is not None: + setattr(record, key, value) + record.updated_at = _now_iso() + if should_persist and self._store is not None: + try: + await self._store.update_run_progress(run_id, **kwargs) + except Exception: + logger.warning("Failed to persist run progress for %s", run_id, exc_info=True) + async def create( self, thread_id: str, diff --git a/backend/packages/harness/deerflow/runtime/runs/store/base.py b/backend/packages/harness/deerflow/runtime/runs/store/base.py index 10c90d7ea..c5ac18212 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/base.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/base.py @@ -95,12 +95,30 @@ class RunStore(abc.ABC): ) -> None: pass + async def update_run_progress( + self, + run_id: str, + *, + total_input_tokens: int | None = None, + total_output_tokens: int | None = None, + total_tokens: int | None = None, + llm_call_count: int | None = None, + lead_agent_tokens: int | None = None, + subagent_tokens: int | None = None, + middleware_tokens: int | None = None, + message_count: int | None = None, + last_ai_message: str | None = None, + first_human_message: str | None = None, + ) -> None: + """Persist a best-effort running snapshot without changing run status.""" + return None + @abc.abstractmethod async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: pass @abc.abstractmethod - async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: + async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: """Aggregate token usage for completed runs in a thread. Returns a dict with keys: total_tokens, total_input_tokens, diff --git a/backend/packages/harness/deerflow/runtime/runs/store/memory.py b/backend/packages/harness/deerflow/runtime/runs/store/memory.py index 56ef02b5b..d241f2ecc 100644 --- a/backend/packages/harness/deerflow/runtime/runs/store/memory.py +++ b/backend/packages/harness/deerflow/runtime/runs/store/memory.py @@ -82,14 +82,22 @@ class MemoryRunStore(RunStore): self._runs[run_id][key] = value self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + async def update_run_progress(self, run_id, **kwargs): + if run_id in self._runs and self._runs[run_id].get("status") == "running": + for key, value in kwargs.items(): + if value is not None: + self._runs[run_id][key] = value + self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() + async def list_pending(self, *, before=None): now = before or datetime.now(UTC).isoformat() results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now] results.sort(key=lambda r: r["created_at"]) return results - async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: - completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")] + async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: + statuses = ("success", "error", "running") if include_active else ("success", "error") + completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses] by_model: dict[str, dict] = {} for r in completed: model = r.get("model_name") or "unknown" diff --git a/backend/packages/harness/deerflow/runtime/runs/worker.py b/backend/packages/harness/deerflow/runtime/runs/worker.py index 694464fe3..d84b3edf9 100644 --- a/backend/packages/harness/deerflow/runtime/runs/worker.py +++ b/backend/packages/harness/deerflow/runtime/runs/worker.py @@ -153,8 +153,6 @@ async def run_agent( journal = None - journal = None - # Track whether "events" was requested but skipped if "events" in requested_modes: logger.info( @@ -177,6 +175,7 @@ async def run_agent( thread_id=thread_id, event_store=event_store, track_token_usage=getattr(run_events_config, "track_token_usage", True), + progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot), ) # 1. Mark running diff --git a/backend/tests/test_run_journal.py b/backend/tests/test_run_journal.py index 8615caa49..0b495954b 100644 --- a/backend/tests/test_run_journal.py +++ b/backend/tests/test_run_journal.py @@ -714,6 +714,110 @@ class TestExternalUsageRecords: assert j._subagent_tokens == 0 +class TestProgressSnapshots: + @pytest.mark.anyio + async def test_on_llm_end_reports_progress_snapshot(self): + snapshots: list[dict] = [] + + async def reporter(snapshot: dict) -> None: + snapshots.append(snapshot) + + store = MemoryRunEventStore() + j = RunJournal( + "r1", + "t1", + store, + flush_threshold=100, + progress_reporter=reporter, + progress_flush_interval=0, + ) + usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15} + j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"]) + await j.flush() + + assert snapshots + assert snapshots[-1]["total_tokens"] == 15 + assert snapshots[-1]["llm_call_count"] == 1 + assert snapshots[-1]["message_count"] == 1 + assert snapshots[-1]["last_ai_message"] == "Answer" + + @pytest.mark.anyio + async def test_throttled_progress_flush_emits_trailing_snapshot(self): + snapshots: list[dict] = [] + trailing_seen = asyncio.Event() + + async def reporter(snapshot: dict) -> None: + snapshots.append(snapshot) + if snapshot["total_tokens"] == 45: + trailing_seen.set() + + store = MemoryRunEventStore() + j = RunJournal( + "r1", + "t1", + store, + flush_threshold=100, + progress_reporter=reporter, + progress_flush_interval=0.01, + ) + j.on_llm_end( + _make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), + run_id=uuid4(), + parent_run_id=None, + tags=["lead_agent"], + ) + j.on_llm_end( + _make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}), + run_id=uuid4(), + parent_run_id=None, + tags=["lead_agent"], + ) + await asyncio.wait_for(trailing_seen.wait(), timeout=1.0) + await j.flush() + + assert len(snapshots) >= 2 + assert snapshots[-1]["total_tokens"] == 45 + assert snapshots[-1]["llm_call_count"] == 2 + assert snapshots[-1]["last_ai_message"] == "Second" + + @pytest.mark.anyio + async def test_flush_cancels_delayed_progress_without_final_progress_write(self): + snapshots: list[dict] = [] + + async def reporter(snapshot: dict) -> None: + snapshots.append(snapshot) + + store = MemoryRunEventStore() + j = RunJournal( + "r1", + "t1", + store, + flush_threshold=100, + progress_reporter=reporter, + progress_flush_interval=10.0, + ) + j.on_llm_end( + _make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}), + run_id=uuid4(), + parent_run_id=None, + tags=["lead_agent"], + ) + await asyncio.sleep(0) + assert snapshots[-1]["total_tokens"] == 15 + j.on_llm_end( + _make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}), + run_id=uuid4(), + parent_run_id=None, + tags=["lead_agent"], + ) + + await asyncio.wait_for(j.flush(), timeout=0.2) + + assert snapshots[-1]["total_tokens"] == 15 + assert snapshots[-1]["llm_call_count"] == 1 + assert snapshots[-1]["last_ai_message"] == "First" + + class TestChatModelStartHumanMessage: """Tests for on_chat_model_start extracting the first human message.""" diff --git a/backend/tests/test_run_repository.py b/backend/tests/test_run_repository.py index 5809db517..f18e51348 100644 --- a/backend/tests/test_run_repository.py +++ b/backend/tests/test_run_repository.py @@ -10,6 +10,7 @@ from sqlalchemy.dialects import postgresql from deerflow.persistence.run import RunRepository from deerflow.runtime import RunManager, RunStatus +from deerflow.runtime.runs.store.base import RunStore async def _make_repo(tmp_path): @@ -26,6 +27,42 @@ async def _cleanup(): await close_engine() +class _CustomRunStoreWithoutProgress(RunStore): + async def put(self, *args, **kwargs): + return None + + async def get(self, *args, **kwargs): + return None + + async def list_by_thread(self, *args, **kwargs): + return [] + + async def update_status(self, *args, **kwargs): + return None + + async def delete(self, *args, **kwargs): + return None + + async def update_model_name(self, *args, **kwargs): + return None + + async def update_run_completion(self, *args, **kwargs): + return None + + async def list_pending(self, *args, **kwargs): + return [] + + async def aggregate_tokens_by_thread(self, *args, **kwargs): + return {} + + +@pytest.mark.anyio +async def test_update_run_progress_defaults_to_noop_for_custom_store(): + store = _CustomRunStoreWithoutProgress() + + await store.update_run_progress("r1", total_tokens=1) + + class TestRunRepository: @pytest.mark.anyio async def test_put_and_get(self, tmp_path): @@ -170,6 +207,69 @@ class TestRunRepository: assert row["total_tokens"] == 100 await _cleanup() + @pytest.mark.anyio + async def test_update_run_progress_keeps_status_running(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", status="running") + await repo.update_run_progress( + "r1", + total_input_tokens=40, + total_output_tokens=10, + total_tokens=50, + llm_call_count=1, + message_count=2, + last_ai_message="partial answer", + ) + row = await repo.get("r1") + assert row["status"] == "running" + assert row["total_tokens"] == 50 + assert row["llm_call_count"] == 1 + assert row["message_count"] == 2 + assert row["last_ai_message"] == "partial answer" + await _cleanup() + + @pytest.mark.anyio + async def test_update_run_progress_preserves_omitted_fields(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", status="running") + await repo.update_run_progress( + "r1", + total_input_tokens=40, + total_output_tokens=10, + total_tokens=50, + llm_call_count=1, + lead_agent_tokens=30, + subagent_tokens=20, + message_count=2, + ) + + await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated") + + row = await repo.get("r1") + assert row["total_input_tokens"] == 40 + assert row["total_output_tokens"] == 10 + assert row["total_tokens"] == 60 + assert row["llm_call_count"] == 1 + assert row["lead_agent_tokens"] == 30 + assert row["subagent_tokens"] == 20 + assert row["message_count"] == 2 + assert row["last_ai_message"] == "updated" + await _cleanup() + + @pytest.mark.anyio + async def test_update_run_progress_skips_terminal_runs(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("r1", thread_id="t1", status="running") + await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1) + + await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2) + + row = await repo.get("r1") + assert row["status"] == "success" + assert row["total_tokens"] == 100 + assert row["llm_call_count"] == 1 + await _cleanup() + @pytest.mark.anyio async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path): repo = await _make_repo(tmp_path) @@ -225,6 +325,28 @@ class TestRunRepository: } await _cleanup() + @pytest.mark.anyio + async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path): + repo = await _make_repo(tmp_path) + await repo.put("success-run", thread_id="t1", status="running") + await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100) + await repo.put("running-run", thread_id="t1", status="running") + await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5) + + without_active = await repo.aggregate_tokens_by_thread("t1") + with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True) + + assert without_active["total_tokens"] == 100 + assert without_active["total_runs"] == 1 + assert with_active["total_tokens"] == 125 + assert with_active["total_runs"] == 2 + assert with_active["by_caller"] == { + "lead_agent": 120, + "subagent": 5, + "middleware": 0, + } + await _cleanup() + @pytest.mark.anyio async def test_list_by_thread_ordered_desc(self, tmp_path): """list_by_thread returns newest first.""" diff --git a/backend/tests/test_thread_token_usage.py b/backend/tests/test_thread_token_usage.py index 713f6aa5f..19f8e0c19 100644 --- a/backend/tests/test_thread_token_usage.py +++ b/backend/tests/test_thread_token_usage.py @@ -53,3 +53,30 @@ def test_thread_token_usage_returns_stable_shape(): }, } run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1") + + +def test_thread_token_usage_can_include_active_runs(): + run_store = MagicMock() + run_store.aggregate_tokens_by_thread = AsyncMock( + return_value={ + "total_tokens": 175, + "total_input_tokens": 120, + "total_output_tokens": 55, + "total_runs": 3, + "by_model": {"unknown": {"tokens": 175, "runs": 3}}, + "by_caller": { + "lead_agent": 145, + "subagent": 25, + "middleware": 5, + }, + }, + ) + app = _make_app(run_store) + + with TestClient(app) as client: + response = client.get("/api/threads/thread-1/token-usage?include_active=true") + + assert response.status_code == 200 + assert response.json()["total_tokens"] == 175 + assert response.json()["total_runs"] == 3 + run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)