perf: use SQL aggregation for feedback stats and thread token usage

Replace the Python-side counting in FeedbackRepository.aggregate_by_run with
a single SELECT COUNT/SUM query. Add an abstract
RunStore.aggregate_tokens_by_thread method, with a SQL GROUP BY implementation
in RunRepository and a pure-Python fallback in MemoryRunStore. Simplify the
thread_token_usage endpoint to delegate to the new method, which eliminates
the risk of silently truncating the result at limit=10000 rows.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-06 11:20:34 +08:00
parent 332fb18b34
commit 0af0ae7fbb
5 changed files with 98 additions and 41 deletions
+2 -29
View File
@@ -310,32 +310,5 @@ async def list_run_events(
async def thread_token_usage(thread_id: str, request: Request) -> dict:
    """Return the aggregated token usage of one thread.

    The aggregation itself is delegated to the run store's
    ``aggregate_tokens_by_thread``; this handler only tags the store's
    result with the ``thread_id`` it was computed for.
    """
    usage = await get_run_store(request).aggregate_tokens_by_thread(thread_id)
    # "thread_id" is listed first so the store's keys are merged after it,
    # exactly as a literal ``{"thread_id": ..., **usage}`` would.
    return {"thread_id": thread_id, **usage}
@@ -8,7 +8,7 @@ from __future__ import annotations
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
from sqlalchemy import select from sqlalchemy import case, func, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.feedback import FeedbackRow from deerflow.persistence.models.feedback import FeedbackRow
@@ -82,13 +82,17 @@ class FeedbackRepository:
return True return True
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
    """Count the feedback rows of a run inside the database.

    A single SELECT returns the total number of rows for the given
    ``thread_id``/``run_id`` together with how many of them carry a
    rating of ``1`` (positive) and ``-1`` (negative).

    Returns:
        A dict with the keys ``run_id``, ``total``, ``positive`` and
        ``negative``.
    """

    def _rating_count(value: int):
        # SUM over zero rows is NULL, so COALESCE pins the empty case to 0.
        return func.coalesce(func.sum(case((FeedbackRow.rating == value, 1), else_=0)), 0)

    query = select(
        func.count().label("total"),
        _rating_count(1).label("positive"),
        _rating_count(-1).label("negative"),
    ).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id)

    async with self._sf() as session:
        result = await session.execute(query)
        # An aggregate without GROUP BY always yields exactly one row.
        counts = result.one()

    return {
        "run_id": run_id,
        "total": counts.total,
        "positive": counts.positive,
        "negative": counts.negative,
    }
@@ -11,7 +11,7 @@ import json
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
from sqlalchemy import select, update from sqlalchemy import func, select, update
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run import RunRow from deerflow.persistence.models.run import RunRow
@@ -171,3 +171,52 @@ class RunRepository(RunStore):
async with self._sf() as session: async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit() await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
    """Aggregate token usage for a thread with a single GROUP BY query.

    Only runs whose status is ``"success"`` or ``"error"`` are counted.
    The database groups the matching rows by ``model_name`` and sums the
    token columns per group; the per-model rows are then folded into
    thread-wide totals in Python.

    Runs without a model name (NULL or empty string) are reported under
    the ``"unknown"`` model, matching the ``or "unknown"`` rule used by the
    in-memory store.

    Returns:
        A dict with the keys ``total_tokens``, ``total_input_tokens``,
        ``total_output_tokens``, ``total_runs``, ``by_model`` (model name
        -> ``{"tokens", "runs"}``) and ``by_caller`` (``lead_agent``,
        ``subagent``, ``middleware``). All counters are 0 and ``by_model``
        is empty when the thread has no completed runs.
    """

    def _summed(column):
        # SUM() yields NULL when every value in the group is NULL.
        return func.coalesce(func.sum(column), 0)

    # Group on the raw column and normalise the name in Python: this keeps
    # the "unknown" rule identical to the in-memory store and avoids
    # repeating a parameterised COALESCE in both SELECT and GROUP BY.
    stmt = (
        select(
            RunRow.model_name.label("model"),
            func.count().label("runs"),
            _summed(RunRow.total_tokens).label("total_tokens"),
            _summed(RunRow.total_input_tokens).label("total_input_tokens"),
            _summed(RunRow.total_output_tokens).label("total_output_tokens"),
            _summed(RunRow.lead_agent_tokens).label("lead_agent"),
            _summed(RunRow.subagent_tokens).label("subagent"),
            _summed(RunRow.middleware_tokens).label("middleware"),
        )
        .where(RunRow.thread_id == thread_id, RunRow.status.in_(("success", "error")))
        .group_by(RunRow.model_name)
    )
    async with self._sf() as session:
        rows = (await session.execute(stmt)).all()

    total_tokens = total_input = total_output = total_runs = 0
    lead_agent = subagent = middleware = 0
    by_model: dict[str, dict] = {}
    for row in rows:
        # int() guards against drivers that hand back Decimal for SUM()
        # (e.g. PostgreSQL on BIGINT columns); it is a no-op for plain ints.
        tokens = int(row.total_tokens)
        runs = int(row.runs)

        # NULL and "" land in separate SQL groups but share one bucket here.
        entry = by_model.setdefault(row.model or "unknown", {"tokens": 0, "runs": 0})
        entry["tokens"] += tokens
        entry["runs"] += runs

        total_tokens += tokens
        total_runs += runs
        total_input += int(row.total_input_tokens)
        total_output += int(row.total_output_tokens)
        lead_agent += int(row.lead_agent)
        subagent += int(row.subagent)
        middleware += int(row.middleware)

    return {
        "total_tokens": total_tokens,
        "total_input_tokens": total_input,
        "total_output_tokens": total_output,
        "total_runs": total_runs,
        "by_model": by_model,
        "by_caller": {
            "lead_agent": lead_agent,
            "subagent": subagent,
            "middleware": middleware,
        },
    }
@@ -84,3 +84,13 @@ class RunStore(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass pass
@abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
    """Summarise the token usage of the completed runs in one thread.

    Implementations return a dict with these keys:

    - ``total_tokens``, ``total_input_tokens``, ``total_output_tokens``
    - ``total_runs``
    - ``by_model``: model_name -> ``{tokens, runs}``
    - ``by_caller``: ``{lead_agent, subagent, middleware}``
    """
@@ -77,3 +77,24 @@ class MemoryRunStore(RunStore):
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now] results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"]) results.sort(key=lambda r: r["created_at"])
return results return results
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
    """Aggregate token usage for a thread from the in-memory run records.

    Only runs of ``thread_id`` whose status is ``"success"`` or ``"error"``
    are counted. A token field that is missing or ``None`` contributes 0,
    and a run without a model name is reported under ``"unknown"``.

    Returns:
        A dict with the keys ``total_tokens``, ``total_input_tokens``,
        ``total_output_tokens``, ``total_runs``, ``by_model`` (model name
        -> ``{"tokens", "runs"}``) and ``by_caller`` (``lead_agent``,
        ``subagent``, ``middleware``).
    """

    def _count(run: dict[str, Any], key: str) -> int:
        # ``run.get(key, 0)`` would still return None for a key that is
        # present but unset, which breaks the additions below.
        return run.get(key) or 0

    total_tokens = total_input = total_output = total_runs = 0
    lead_agent = subagent = middleware = 0
    by_model: dict[str, dict] = {}

    # Single pass over the store instead of one pass per counter.
    for run in self._runs.values():
        if run["thread_id"] != thread_id:
            continue
        if run.get("status") not in ("success", "error"):
            continue

        tokens = _count(run, "total_tokens")
        entry = by_model.setdefault(run.get("model_name") or "unknown", {"tokens": 0, "runs": 0})
        entry["tokens"] += tokens
        entry["runs"] += 1

        total_runs += 1
        total_tokens += tokens
        total_input += _count(run, "total_input_tokens")
        total_output += _count(run, "total_output_tokens")
        lead_agent += _count(run, "lead_agent_tokens")
        subagent += _count(run, "subagent_tokens")
        middleware += _count(run, "middleware_tokens")

    return {
        "total_tokens": total_tokens,
        "total_input_tokens": total_input,
        "total_output_tokens": total_output,
        "total_runs": total_runs,
        "by_model": by_model,
        "by_caller": {
            "lead_agent": lead_agent,
            "subagent": subagent,
            "middleware": middleware,
        },
    }