fix(runs): expose active progress counters (#3148)

* fix(runs): expose active progress counters

* fix(runs): avoid delayed progress flush on completion

* fix(runs): tighten progress snapshot semantics

* fix(runs): preserve omitted progress fields

* chore(runs): remove duplicate journal initialization
This commit is contained in:
Lawrance_YXLiao
2026-05-22 21:42:14 +08:00
committed by GitHub
parent 914d6a4f1c
commit 2eeb597985
10 changed files with 468 additions and 10 deletions
@@ -227,9 +227,48 @@ class RunRepository(RunStore):
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Update token usage + convenience fields while a run is still active."""
values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
optional_counters = {
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
}
for key, value in optional_counters.items():
if value is not None:
values[key] = value
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error"))
statuses = ("success", "error", "running") if include_active else ("success", "error")
_completed = RunRow.status.in_(statuses)
_thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown")