mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
fix(persistence): reuse token usage model grouping expression (#2910)
This commit is contained in:
@@ -223,10 +223,11 @@ class RunRepository(RunStore):
|
|||||||
"""Aggregate token usage via a single SQL GROUP BY query."""
|
"""Aggregate token usage via a single SQL GROUP BY query."""
|
||||||
_completed = RunRow.status.in_(("success", "error"))
|
_completed = RunRow.status.in_(("success", "error"))
|
||||||
_thread = RunRow.thread_id == thread_id
|
_thread = RunRow.thread_id == thread_id
|
||||||
|
model_name = func.coalesce(RunRow.model_name, "unknown")
|
||||||
|
|
||||||
stmt = (
|
stmt = (
|
||||||
select(
|
select(
|
||||||
func.coalesce(RunRow.model_name, "unknown").label("model"),
|
model_name.label("model"),
|
||||||
func.count().label("runs"),
|
func.count().label("runs"),
|
||||||
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
|
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
|
||||||
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
|
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
|
||||||
@@ -236,7 +237,7 @@ class RunRepository(RunStore):
|
|||||||
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
|
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
|
||||||
)
|
)
|
||||||
.where(_thread, _completed)
|
.where(_thread, _completed)
|
||||||
.group_by(func.coalesce(RunRow.model_name, "unknown"))
|
.group_by(model_name)
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
|
|||||||
@@ -3,7 +3,10 @@
|
|||||||
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
Uses a temp SQLite DB to test ORM-backed CRUD operations.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
from deerflow.persistence.run import RunRepository
|
from deerflow.persistence.run import RunRepository
|
||||||
|
|
||||||
@@ -278,3 +281,48 @@ class TestRunRepository:
|
|||||||
assert row4["model_name"] is None
|
assert row4["model_name"] is None
|
||||||
|
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
    """Check that SELECT and GROUP BY share one model-name bind parameter.

    A fake session records the statement passed to ``execute`` and returns
    no rows. The recorded statement is then compiled with the PostgreSQL
    dialect, and the bind-parameter name inside the ``coalesce(...)`` of the
    SELECT list is compared with the one in the GROUP BY clause.
    """
    executed_statements = []

    class EmptyResult:
        """Result stand-in whose ``all()`` yields no rows."""

        def all(self):
            return []

    class RecordingSession:
        """Session stand-in that records every executed statement."""

        async def execute(self, stmt):
            executed_statements.append(stmt)
            return EmptyResult()

    class RecordingSessionContext:
        """Async context manager handing out a ``RecordingSession``."""

        async def __aenter__(self):
            return RecordingSession()

        async def __aexit__(self, exc_type, exc, tb):
            return None

    def session_factory():
        return RecordingSessionContext()

    repository = RunRepository(session_factory)
    aggregate = await repository.aggregate_tokens_by_thread("t1")

    # With no rows returned, every counter is expected to be zero.
    expected_aggregate = {
        "total_tokens": 0,
        "total_input_tokens": 0,
        "total_output_tokens": 0,
        "total_runs": 0,
        "by_model": {},
        "by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
    }
    assert aggregate == expected_aggregate
    assert len(executed_statements) == 1

    rendered = str(executed_statements[0].compile(dialect=postgresql.dialect()))
    before_group_by, after_group_by = rendered.split(" GROUP BY ", maxsplit=1)

    # Group 1 captures the bind-parameter name used as the coalesce fallback.
    coalesce_source = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
    labelled_column = re.compile(coalesce_source + r" AS model")
    bare_expression = re.compile(coalesce_source)

    in_select = labelled_column.search(before_group_by)
    in_group_by = bare_expression.fullmatch(after_group_by.strip())

    assert in_select is not None
    assert in_group_by is not None
    # Identical parameter names mean both clauses render the same expression.
    assert in_select.group(1) == in_group_by.group(1)
|
||||||
|
|||||||
Reference in New Issue
Block a user