fix(runs): expose active progress counters (#3148)

* fix(runs): expose active progress counters

* fix(runs): avoid delayed progress flush on completion

* fix(runs): tighten progress snapshot semantics

* fix(runs): preserve omitted progress fields

* chore(runs): remove duplicate journal initialization
This commit is contained in:
Lawrance_YXLiao
2026-05-22 21:42:14 +08:00
committed by GitHub
parent 914d6a4f1c
commit 2eeb597985
10 changed files with 468 additions and 10 deletions
@@ -20,7 +20,7 @@ from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import Mapping
from collections.abc import Awaitable, Callable, Mapping
from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast
from uuid import UUID
@@ -46,6 +46,8 @@ class RunJournal(BaseCallbackHandler):
*,
track_token_usage: bool = True,
flush_threshold: int = 20,
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
progress_flush_interval: float = 5.0,
):
super().__init__()
self.run_id = run_id
@@ -53,10 +55,16 @@ class RunJournal(BaseCallbackHandler):
self._store = event_store
self._track_tokens = track_token_usage
self._flush_threshold = flush_threshold
self._progress_reporter = progress_reporter
self._progress_flush_interval = progress_flush_interval
# Write buffer
self._buffer: list[dict] = []
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
self._pending_progress_task: asyncio.Task[None] | None = None
self._pending_progress_delayed = False
self._progress_dirty = False
self._last_progress_flush = 0.0
# Token accumulators
self._total_input_tokens = 0
@@ -294,6 +302,8 @@ class RunJournal(BaseCallbackHandler):
else:
self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
if messages:
self._counted_message_llm_run_ids.add(str(run_id))
@@ -445,6 +455,8 @@ class RunJournal(BaseCallbackHandler):
else:
self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None
@@ -474,6 +486,14 @@ class RunJournal(BaseCallbackHandler):
"""Force flush remaining buffer. Called in worker's finally block."""
if self._pending_flush_tasks:
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
while self._pending_progress_task is not None and not self._pending_progress_task.done():
if self._pending_progress_delayed:
self._pending_progress_task.cancel()
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
self._progress_dirty = False
self._pending_progress_delayed = False
break
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
while self._buffer:
batch = self._buffer[: self._flush_threshold]
@@ -484,6 +504,57 @@ class RunJournal(BaseCallbackHandler):
self._buffer = batch + self._buffer
raise
def _schedule_progress_flush(self) -> None:
"""Best-effort throttled progress snapshot for active run visibility."""
if self._progress_reporter is None:
return
now = time.monotonic()
elapsed = now - self._last_progress_flush
if elapsed < self._progress_flush_interval:
self._progress_dirty = True
self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed)
return
if self._pending_progress_task is not None and not self._pending_progress_task.done():
self._progress_dirty = True
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
self._progress_dirty = False
self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
def _schedule_delayed_progress_flush(self, delay: float) -> None:
if self._pending_progress_task is not None and not self._pending_progress_task.done():
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
delay = max(0.0, delay)
self._pending_progress_delayed = delay > 0
self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay))
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
if self._progress_reporter is None:
return
if delay > 0:
self._pending_progress_delayed = True
await asyncio.sleep(delay)
self._pending_progress_delayed = False
dirty_before_write = self._progress_dirty
self._progress_dirty = False
snapshot_to_write = snapshot or self.get_completion_data()
try:
await self._progress_reporter(snapshot_to_write)
self._last_progress_flush = time.monotonic()
except Exception:
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
if dirty_before_write or self._progress_dirty:
self._progress_dirty = False
self._pending_progress_task = None
self._schedule_delayed_progress_flush(self._progress_flush_interval)
def get_completion_data(self) -> dict:
"""Return accumulated token and message data for run completion."""
return {
@@ -38,6 +38,16 @@ class RunRecord:
error: str | None = None
model_name: str | None = None
store_only: bool = False
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
last_ai_message: str | None = None
first_human_message: str | None = None
class RunManager:
@@ -102,16 +112,53 @@ class RunManager:
error=row.get("error"),
model_name=row.get("model_name"),
store_only=True,
total_input_tokens=row.get("total_input_tokens") or 0,
total_output_tokens=row.get("total_output_tokens") or 0,
total_tokens=row.get("total_tokens") or 0,
llm_call_count=row.get("llm_call_count") or 0,
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
subagent_tokens=row.get("subagent_tokens") or 0,
middleware_tokens=row.get("middleware_tokens") or 0,
message_count=row.get("message_count") or 0,
last_ai_message=row.get("last_ai_message"),
first_human_message=row.get("first_human_message"),
)
async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store."""
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
for key, value in kwargs.items():
if key == "status":
continue
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if self._store is not None:
try:
await self._store.update_run_completion(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def update_run_progress(self, run_id: str, **kwargs) -> None:
"""Persist a running token/message snapshot without changing status."""
should_persist = True
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
should_persist = record.status == RunStatus.running
if record is not None and should_persist:
for key, value in kwargs.items():
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if should_persist and self._store is not None:
try:
await self._store.update_run_progress(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)
async def create(
self,
thread_id: str,
@@ -95,12 +95,30 @@ class RunStore(abc.ABC):
) -> None:
pass
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Persist a best-effort running snapshot without changing run status."""
return None
@abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass
@abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread.
Returns a dict with keys: total_tokens, total_input_tokens,
@@ -82,14 +82,22 @@ class MemoryRunStore(RunStore):
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def update_run_progress(self, run_id, **kwargs):
if run_id in self._runs and self._runs[run_id].get("status") == "running":
for key, value in kwargs.items():
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def list_pending(self, *, before=None):
now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"])
return results
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
statuses = ("success", "error", "running") if include_active else ("success", "error")
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
by_model: dict[str, dict] = {}
for r in completed:
model = r.get("model_name") or "unknown"
@@ -153,8 +153,6 @@ async def run_agent(
journal = None
journal = None
# Track whether "events" was requested but skipped
if "events" in requested_modes:
logger.info(
@@ -177,6 +175,7 @@ async def run_agent(
thread_id=thread_id,
event_store=event_store,
track_token_usage=getattr(run_events_config, "track_token_usage", True),
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
)
# 1. Mark running