mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
fix(runs): expose active progress counters (#3148)
* fix(runs): expose active progress counters * fix(runs): avoid delayed progress flush on completion * fix(runs): tighten progress snapshot semantics * fix(runs): preserve omitted progress fields * chore(runs): remove duplicate journal initialization
This commit is contained in:
@@ -20,7 +20,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from datetime import UTC, datetime
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from uuid import UUID
|
||||
@@ -46,6 +46,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
*,
|
||||
track_token_usage: bool = True,
|
||||
flush_threshold: int = 20,
|
||||
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
|
||||
progress_flush_interval: float = 5.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.run_id = run_id
|
||||
@@ -53,10 +55,16 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._store = event_store
|
||||
self._track_tokens = track_token_usage
|
||||
self._flush_threshold = flush_threshold
|
||||
self._progress_reporter = progress_reporter
|
||||
self._progress_flush_interval = progress_flush_interval
|
||||
|
||||
# Write buffer
|
||||
self._buffer: list[dict] = []
|
||||
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
||||
self._pending_progress_task: asyncio.Task[None] | None = None
|
||||
self._pending_progress_delayed = False
|
||||
self._progress_dirty = False
|
||||
self._last_progress_flush = 0.0
|
||||
|
||||
# Token accumulators
|
||||
self._total_input_tokens = 0
|
||||
@@ -294,6 +302,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
self._schedule_progress_flush()
|
||||
|
||||
if messages:
|
||||
self._counted_message_llm_run_ids.add(str(run_id))
|
||||
|
||||
@@ -445,6 +455,8 @@ class RunJournal(BaseCallbackHandler):
|
||||
else:
|
||||
self._lead_agent_tokens += total_tk
|
||||
|
||||
self._schedule_progress_flush()
|
||||
|
||||
def set_first_human_message(self, content: str) -> None:
|
||||
"""Record the first human message for convenience fields."""
|
||||
self._first_human_msg = content[:2000] if content else None
|
||||
@@ -474,6 +486,14 @@ class RunJournal(BaseCallbackHandler):
|
||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||
if self._pending_flush_tasks:
|
||||
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
||||
while self._pending_progress_task is not None and not self._pending_progress_task.done():
|
||||
if self._pending_progress_delayed:
|
||||
self._pending_progress_task.cancel()
|
||||
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
|
||||
self._progress_dirty = False
|
||||
self._pending_progress_delayed = False
|
||||
break
|
||||
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
|
||||
|
||||
while self._buffer:
|
||||
batch = self._buffer[: self._flush_threshold]
|
||||
@@ -484,6 +504,57 @@ class RunJournal(BaseCallbackHandler):
|
||||
self._buffer = batch + self._buffer
|
||||
raise
|
||||
|
||||
def _schedule_progress_flush(self) -> None:
|
||||
"""Best-effort throttled progress snapshot for active run visibility."""
|
||||
if self._progress_reporter is None:
|
||||
return
|
||||
now = time.monotonic()
|
||||
elapsed = now - self._last_progress_flush
|
||||
if elapsed < self._progress_flush_interval:
|
||||
self._progress_dirty = True
|
||||
self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed)
|
||||
return
|
||||
if self._pending_progress_task is not None and not self._pending_progress_task.done():
|
||||
self._progress_dirty = True
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
self._progress_dirty = False
|
||||
self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
|
||||
|
||||
def _schedule_delayed_progress_flush(self, delay: float) -> None:
|
||||
if self._pending_progress_task is not None and not self._pending_progress_task.done():
|
||||
return
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return
|
||||
delay = max(0.0, delay)
|
||||
self._pending_progress_delayed = delay > 0
|
||||
self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay))
|
||||
|
||||
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
|
||||
if self._progress_reporter is None:
|
||||
return
|
||||
if delay > 0:
|
||||
self._pending_progress_delayed = True
|
||||
await asyncio.sleep(delay)
|
||||
self._pending_progress_delayed = False
|
||||
dirty_before_write = self._progress_dirty
|
||||
self._progress_dirty = False
|
||||
snapshot_to_write = snapshot or self.get_completion_data()
|
||||
try:
|
||||
await self._progress_reporter(snapshot_to_write)
|
||||
self._last_progress_flush = time.monotonic()
|
||||
except Exception:
|
||||
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
|
||||
if dirty_before_write or self._progress_dirty:
|
||||
self._progress_dirty = False
|
||||
self._pending_progress_task = None
|
||||
self._schedule_delayed_progress_flush(self._progress_flush_interval)
|
||||
|
||||
def get_completion_data(self) -> dict:
|
||||
"""Return accumulated token and message data for run completion."""
|
||||
return {
|
||||
|
||||
@@ -38,6 +38,16 @@ class RunRecord:
|
||||
error: str | None = None
|
||||
model_name: str | None = None
|
||||
store_only: bool = False
|
||||
total_input_tokens: int = 0
|
||||
total_output_tokens: int = 0
|
||||
total_tokens: int = 0
|
||||
llm_call_count: int = 0
|
||||
lead_agent_tokens: int = 0
|
||||
subagent_tokens: int = 0
|
||||
middleware_tokens: int = 0
|
||||
message_count: int = 0
|
||||
last_ai_message: str | None = None
|
||||
first_human_message: str | None = None
|
||||
|
||||
|
||||
class RunManager:
|
||||
@@ -102,16 +112,53 @@ class RunManager:
|
||||
error=row.get("error"),
|
||||
model_name=row.get("model_name"),
|
||||
store_only=True,
|
||||
total_input_tokens=row.get("total_input_tokens") or 0,
|
||||
total_output_tokens=row.get("total_output_tokens") or 0,
|
||||
total_tokens=row.get("total_tokens") or 0,
|
||||
llm_call_count=row.get("llm_call_count") or 0,
|
||||
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
|
||||
subagent_tokens=row.get("subagent_tokens") or 0,
|
||||
middleware_tokens=row.get("middleware_tokens") or 0,
|
||||
message_count=row.get("message_count") or 0,
|
||||
last_ai_message=row.get("last_ai_message"),
|
||||
first_human_message=row.get("first_human_message"),
|
||||
)
|
||||
|
||||
async def update_run_completion(self, run_id: str, **kwargs) -> None:
|
||||
"""Persist token usage and completion data to the backing store."""
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is not None:
|
||||
for key, value in kwargs.items():
|
||||
if key == "status":
|
||||
continue
|
||||
if hasattr(record, key) and value is not None:
|
||||
setattr(record, key, value)
|
||||
record.updated_at = _now_iso()
|
||||
if self._store is not None:
|
||||
try:
|
||||
await self._store.update_run_completion(run_id, **kwargs)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
||||
|
||||
async def update_run_progress(self, run_id: str, **kwargs) -> None:
|
||||
"""Persist a running token/message snapshot without changing status."""
|
||||
should_persist = True
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is not None:
|
||||
should_persist = record.status == RunStatus.running
|
||||
if record is not None and should_persist:
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(record, key) and value is not None:
|
||||
setattr(record, key, value)
|
||||
record.updated_at = _now_iso()
|
||||
if should_persist and self._store is not None:
|
||||
try:
|
||||
await self._store.update_run_progress(run_id, **kwargs)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)
|
||||
|
||||
async def create(
|
||||
self,
|
||||
thread_id: str,
|
||||
|
||||
@@ -95,12 +95,30 @@ class RunStore(abc.ABC):
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def update_run_progress(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
total_input_tokens: int | None = None,
|
||||
total_output_tokens: int | None = None,
|
||||
total_tokens: int | None = None,
|
||||
llm_call_count: int | None = None,
|
||||
lead_agent_tokens: int | None = None,
|
||||
subagent_tokens: int | None = None,
|
||||
middleware_tokens: int | None = None,
|
||||
message_count: int | None = None,
|
||||
last_ai_message: str | None = None,
|
||||
first_human_message: str | None = None,
|
||||
) -> None:
|
||||
"""Persist a best-effort running snapshot without changing run status."""
|
||||
return None
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
||||
"""Aggregate token usage for completed runs in a thread.
|
||||
|
||||
Returns a dict with keys: total_tokens, total_input_tokens,
|
||||
|
||||
@@ -82,14 +82,22 @@ class MemoryRunStore(RunStore):
|
||||
self._runs[run_id][key] = value
|
||||
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
||||
|
||||
async def update_run_progress(self, run_id, **kwargs):
|
||||
if run_id in self._runs and self._runs[run_id].get("status") == "running":
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
self._runs[run_id][key] = value
|
||||
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
||||
|
||||
async def list_pending(self, *, before=None):
|
||||
now = before or datetime.now(UTC).isoformat()
|
||||
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
|
||||
results.sort(key=lambda r: r["created_at"])
|
||||
return results
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
||||
statuses = ("success", "error", "running") if include_active else ("success", "error")
|
||||
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
|
||||
by_model: dict[str, dict] = {}
|
||||
for r in completed:
|
||||
model = r.get("model_name") or "unknown"
|
||||
|
||||
@@ -153,8 +153,6 @@ async def run_agent(
|
||||
|
||||
journal = None
|
||||
|
||||
journal = None
|
||||
|
||||
# Track whether "events" was requested but skipped
|
||||
if "events" in requested_modes:
|
||||
logger.info(
|
||||
@@ -177,6 +175,7 @@ async def run_agent(
|
||||
thread_id=thread_id,
|
||||
event_store=event_store,
|
||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
|
||||
)
|
||||
|
||||
# 1. Mark running
|
||||
|
||||
Reference in New Issue
Block a user