fix(runs): expose active progress counters (#3148)

* fix(runs): expose active progress counters

* fix(runs): avoid delayed progress flush on completion

* fix(runs): tighten progress snapshot semantics

* fix(runs): preserve omitted progress fields

* chore(runs): remove duplicate journal initialization
This commit is contained in:
Lawrance_YXLiao
2026-05-22 21:42:14 +08:00
committed by GitHub
parent 914d6a4f1c
commit 2eeb597985
10 changed files with 468 additions and 10 deletions
+25 -2
View File
@@ -66,6 +66,14 @@ class RunResponse(BaseModel):
multitask_strategy: str = "reject" multitask_strategy: str = "reject"
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
class ThreadTokenUsageModelBreakdown(BaseModel): class ThreadTokenUsageModelBreakdown(BaseModel):
@@ -111,6 +119,14 @@ def _record_to_response(record: RunRecord) -> RunResponse:
multitask_strategy=record.multitask_strategy, multitask_strategy=record.multitask_strategy,
created_at=record.created_at, created_at=record.created_at,
updated_at=record.updated_at, updated_at=record.updated_at,
total_input_tokens=record.total_input_tokens,
total_output_tokens=record.total_output_tokens,
total_tokens=record.total_tokens,
llm_call_count=record.llm_call_count,
lead_agent_tokens=record.lead_agent_tokens,
subagent_tokens=record.subagent_tokens,
middleware_tokens=record.middleware_tokens,
message_count=record.message_count,
) )
@@ -402,8 +418,15 @@ async def list_run_events(
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse) @router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
@require_permission("threads", "read", owner_check=True) @require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse: async def thread_token_usage(
thread_id: str,
request: Request,
include_active: bool = Query(default=False, description="Include running run progress snapshots"),
) -> ThreadTokenUsageResponse:
"""Thread-level token usage aggregation.""" """Thread-level token usage aggregation."""
run_store = get_run_store(request) run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id) if include_active:
agg = await run_store.aggregate_tokens_by_thread(thread_id, include_active=True)
else:
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return ThreadTokenUsageResponse(thread_id=thread_id, **agg) return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
@@ -227,9 +227,48 @@ class RunRepository(RunStore):
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit() await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Update token usage + convenience fields while a run is still active."""
values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
optional_counters = {
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
}
for key, value in optional_counters.items():
if value is not None:
values[key] = value
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage via a single SQL GROUP BY query.""" """Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error")) statuses = ("success", "error", "running") if include_active else ("success", "error")
_completed = RunRow.status.in_(statuses)
_thread = RunRow.thread_id == thread_id _thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown") model_name = func.coalesce(RunRow.model_name, "unknown")
@@ -20,7 +20,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time import time
from collections.abc import Mapping from collections.abc import Awaitable, Callable, Mapping
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from uuid import UUID from uuid import UUID
@@ -46,6 +46,8 @@ class RunJournal(BaseCallbackHandler):
*, *,
track_token_usage: bool = True, track_token_usage: bool = True,
flush_threshold: int = 20, flush_threshold: int = 20,
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
progress_flush_interval: float = 5.0,
): ):
super().__init__() super().__init__()
self.run_id = run_id self.run_id = run_id
@@ -53,10 +55,16 @@ class RunJournal(BaseCallbackHandler):
self._store = event_store self._store = event_store
self._track_tokens = track_token_usage self._track_tokens = track_token_usage
self._flush_threshold = flush_threshold self._flush_threshold = flush_threshold
self._progress_reporter = progress_reporter
self._progress_flush_interval = progress_flush_interval
# Write buffer # Write buffer
self._buffer: list[dict] = [] self._buffer: list[dict] = []
self._pending_flush_tasks: set[asyncio.Task[None]] = set() self._pending_flush_tasks: set[asyncio.Task[None]] = set()
self._pending_progress_task: asyncio.Task[None] | None = None
self._pending_progress_delayed = False
self._progress_dirty = False
self._last_progress_flush = 0.0
# Token accumulators # Token accumulators
self._total_input_tokens = 0 self._total_input_tokens = 0
@@ -294,6 +302,8 @@ class RunJournal(BaseCallbackHandler):
else: else:
self._lead_agent_tokens += total_tk self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
if messages: if messages:
self._counted_message_llm_run_ids.add(str(run_id)) self._counted_message_llm_run_ids.add(str(run_id))
@@ -445,6 +455,8 @@ class RunJournal(BaseCallbackHandler):
else: else:
self._lead_agent_tokens += total_tk self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
def set_first_human_message(self, content: str) -> None: def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields.""" """Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None self._first_human_msg = content[:2000] if content else None
@@ -474,6 +486,14 @@ class RunJournal(BaseCallbackHandler):
"""Force flush remaining buffer. Called in worker's finally block.""" """Force flush remaining buffer. Called in worker's finally block."""
if self._pending_flush_tasks: if self._pending_flush_tasks:
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True) await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
while self._pending_progress_task is not None and not self._pending_progress_task.done():
if self._pending_progress_delayed:
self._pending_progress_task.cancel()
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
self._progress_dirty = False
self._pending_progress_delayed = False
break
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
while self._buffer: while self._buffer:
batch = self._buffer[: self._flush_threshold] batch = self._buffer[: self._flush_threshold]
@@ -484,6 +504,57 @@ class RunJournal(BaseCallbackHandler):
self._buffer = batch + self._buffer self._buffer = batch + self._buffer
raise raise
def _schedule_progress_flush(self) -> None:
"""Best-effort throttled progress snapshot for active run visibility."""
if self._progress_reporter is None:
return
now = time.monotonic()
elapsed = now - self._last_progress_flush
if elapsed < self._progress_flush_interval:
self._progress_dirty = True
self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed)
return
if self._pending_progress_task is not None and not self._pending_progress_task.done():
self._progress_dirty = True
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
self._progress_dirty = False
self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
def _schedule_delayed_progress_flush(self, delay: float) -> None:
if self._pending_progress_task is not None and not self._pending_progress_task.done():
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
delay = max(0.0, delay)
self._pending_progress_delayed = delay > 0
self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay))
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
if self._progress_reporter is None:
return
if delay > 0:
self._pending_progress_delayed = True
await asyncio.sleep(delay)
self._pending_progress_delayed = False
dirty_before_write = self._progress_dirty
self._progress_dirty = False
snapshot_to_write = snapshot or self.get_completion_data()
try:
await self._progress_reporter(snapshot_to_write)
self._last_progress_flush = time.monotonic()
except Exception:
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
if dirty_before_write or self._progress_dirty:
self._progress_dirty = False
self._pending_progress_task = None
self._schedule_delayed_progress_flush(self._progress_flush_interval)
def get_completion_data(self) -> dict: def get_completion_data(self) -> dict:
"""Return accumulated token and message data for run completion.""" """Return accumulated token and message data for run completion."""
return { return {
@@ -38,6 +38,16 @@ class RunRecord:
error: str | None = None error: str | None = None
model_name: str | None = None model_name: str | None = None
store_only: bool = False store_only: bool = False
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
last_ai_message: str | None = None
first_human_message: str | None = None
class RunManager: class RunManager:
@@ -102,16 +112,53 @@ class RunManager:
error=row.get("error"), error=row.get("error"),
model_name=row.get("model_name"), model_name=row.get("model_name"),
store_only=True, store_only=True,
total_input_tokens=row.get("total_input_tokens") or 0,
total_output_tokens=row.get("total_output_tokens") or 0,
total_tokens=row.get("total_tokens") or 0,
llm_call_count=row.get("llm_call_count") or 0,
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
subagent_tokens=row.get("subagent_tokens") or 0,
middleware_tokens=row.get("middleware_tokens") or 0,
message_count=row.get("message_count") or 0,
last_ai_message=row.get("last_ai_message"),
first_human_message=row.get("first_human_message"),
) )
async def update_run_completion(self, run_id: str, **kwargs) -> None: async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store.""" """Persist token usage and completion data to the backing store."""
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
for key, value in kwargs.items():
if key == "status":
continue
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if self._store is not None: if self._store is not None:
try: try:
await self._store.update_run_completion(run_id, **kwargs) await self._store.update_run_completion(run_id, **kwargs)
except Exception: except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def update_run_progress(self, run_id: str, **kwargs) -> None:
"""Persist a running token/message snapshot without changing status."""
should_persist = True
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
should_persist = record.status == RunStatus.running
if record is not None and should_persist:
for key, value in kwargs.items():
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if should_persist and self._store is not None:
try:
await self._store.update_run_progress(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)
async def create( async def create(
self, self,
thread_id: str, thread_id: str,
@@ -95,12 +95,30 @@ class RunStore(abc.ABC):
) -> None: ) -> None:
pass pass
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Persist a best-effort running snapshot without changing run status."""
return None
@abc.abstractmethod @abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread. """Aggregate token usage for completed runs in a thread.
Returns a dict with keys: total_tokens, total_input_tokens, Returns a dict with keys: total_tokens, total_input_tokens,
@@ -82,14 +82,22 @@ class MemoryRunStore(RunStore):
self._runs[run_id][key] = value self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def update_run_progress(self, run_id, **kwargs):
if run_id in self._runs and self._runs[run_id].get("status") == "running":
for key, value in kwargs.items():
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def list_pending(self, *, before=None): async def list_pending(self, *, before=None):
now = before or datetime.now(UTC).isoformat() now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now] results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"]) results.sort(key=lambda r: r["created_at"])
return results return results
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")] statuses = ("success", "error", "running") if include_active else ("success", "error")
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
by_model: dict[str, dict] = {} by_model: dict[str, dict] = {}
for r in completed: for r in completed:
model = r.get("model_name") or "unknown" model = r.get("model_name") or "unknown"
@@ -153,8 +153,6 @@ async def run_agent(
journal = None journal = None
journal = None
# Track whether "events" was requested but skipped # Track whether "events" was requested but skipped
if "events" in requested_modes: if "events" in requested_modes:
logger.info( logger.info(
@@ -177,6 +175,7 @@ async def run_agent(
thread_id=thread_id, thread_id=thread_id,
event_store=event_store, event_store=event_store,
track_token_usage=getattr(run_events_config, "track_token_usage", True), track_token_usage=getattr(run_events_config, "track_token_usage", True),
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
) )
# 1. Mark running # 1. Mark running
+104
View File
@@ -714,6 +714,110 @@ class TestExternalUsageRecords:
assert j._subagent_tokens == 0 assert j._subagent_tokens == 0
class TestProgressSnapshots:
@pytest.mark.anyio
async def test_on_llm_end_reports_progress_snapshot(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0,
)
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
await j.flush()
assert snapshots
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["message_count"] == 1
assert snapshots[-1]["last_ai_message"] == "Answer"
@pytest.mark.anyio
async def test_throttled_progress_flush_emits_trailing_snapshot(self):
snapshots: list[dict] = []
trailing_seen = asyncio.Event()
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
if snapshot["total_tokens"] == 45:
trailing_seen.set()
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0.01,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
await j.flush()
assert len(snapshots) >= 2
assert snapshots[-1]["total_tokens"] == 45
assert snapshots[-1]["llm_call_count"] == 2
assert snapshots[-1]["last_ai_message"] == "Second"
@pytest.mark.anyio
async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=10.0,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.sleep(0)
assert snapshots[-1]["total_tokens"] == 15
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(j.flush(), timeout=0.2)
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["last_ai_message"] == "First"
class TestChatModelStartHumanMessage: class TestChatModelStartHumanMessage:
"""Tests for on_chat_model_start extracting the first human message.""" """Tests for on_chat_model_start extracting the first human message."""
+122
View File
@@ -10,6 +10,7 @@ from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository from deerflow.persistence.run import RunRepository
from deerflow.runtime import RunManager, RunStatus from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.store.base import RunStore
async def _make_repo(tmp_path): async def _make_repo(tmp_path):
@@ -26,6 +27,42 @@ async def _cleanup():
await close_engine() await close_engine()
class _CustomRunStoreWithoutProgress(RunStore):
async def put(self, *args, **kwargs):
return None
async def get(self, *args, **kwargs):
return None
async def list_by_thread(self, *args, **kwargs):
return []
async def update_status(self, *args, **kwargs):
return None
async def delete(self, *args, **kwargs):
return None
async def update_model_name(self, *args, **kwargs):
return None
async def update_run_completion(self, *args, **kwargs):
return None
async def list_pending(self, *args, **kwargs):
return []
async def aggregate_tokens_by_thread(self, *args, **kwargs):
return {}
@pytest.mark.anyio
async def test_update_run_progress_defaults_to_noop_for_custom_store():
store = _CustomRunStoreWithoutProgress()
await store.update_run_progress("r1", total_tokens=1)
class TestRunRepository: class TestRunRepository:
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_and_get(self, tmp_path): async def test_put_and_get(self, tmp_path):
@@ -170,6 +207,69 @@ class TestRunRepository:
assert row["total_tokens"] == 100 assert row["total_tokens"] == 100
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_keeps_status_running(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
message_count=2,
last_ai_message="partial answer",
)
row = await repo.get("r1")
assert row["status"] == "running"
assert row["total_tokens"] == 50
assert row["llm_call_count"] == 1
assert row["message_count"] == 2
assert row["last_ai_message"] == "partial answer"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
lead_agent_tokens=30,
subagent_tokens=20,
message_count=2,
)
await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")
row = await repo.get("r1")
assert row["total_input_tokens"] == 40
assert row["total_output_tokens"] == 10
assert row["total_tokens"] == 60
assert row["llm_call_count"] == 1
assert row["lead_agent_tokens"] == 30
assert row["subagent_tokens"] == 20
assert row["message_count"] == 2
assert row["last_ai_message"] == "updated"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)
await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)
row = await repo.get("r1")
assert row["status"] == "success"
assert row["total_tokens"] == 100
assert row["llm_call_count"] == 1
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path): async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
@@ -225,6 +325,28 @@ class TestRunRepository:
} }
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("success-run", thread_id="t1", status="running")
await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
await repo.put("running-run", thread_id="t1", status="running")
await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)
without_active = await repo.aggregate_tokens_by_thread("t1")
with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)
assert without_active["total_tokens"] == 100
assert without_active["total_runs"] == 1
assert with_active["total_tokens"] == 125
assert with_active["total_runs"] == 2
assert with_active["by_caller"] == {
"lead_agent": 120,
"subagent": 5,
"middleware": 0,
}
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_list_by_thread_ordered_desc(self, tmp_path): async def test_list_by_thread_ordered_desc(self, tmp_path):
"""list_by_thread returns newest first.""" """list_by_thread returns newest first."""
+27
View File
@@ -53,3 +53,30 @@ def test_thread_token_usage_returns_stable_shape():
}, },
} }
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1") run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
def test_thread_token_usage_can_include_active_runs():
run_store = MagicMock()
run_store.aggregate_tokens_by_thread = AsyncMock(
return_value={
"total_tokens": 175,
"total_input_tokens": 120,
"total_output_tokens": 55,
"total_runs": 3,
"by_model": {"unknown": {"tokens": 175, "runs": 3}},
"by_caller": {
"lead_agent": 145,
"subagent": 25,
"middleware": 5,
},
},
)
app = _make_app(run_store)
with TestClient(app) as client:
response = client.get("/api/threads/thread-1/token-usage?include_active=true")
assert response.status_code == 200
assert response.json()["total_tokens"] == 175
assert response.json()["total_runs"] == 3
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)