mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 11a362e5e5 | |||
| 85402405ec | |||
| 43eb643910 | |||
| f3e3a350ce | |||
| 0fae7c9cbb |
@@ -66,14 +66,6 @@ class RunResponse(BaseModel):
|
|||||||
multitask_strategy: str = "reject"
|
multitask_strategy: str = "reject"
|
||||||
created_at: str = ""
|
created_at: str = ""
|
||||||
updated_at: str = ""
|
updated_at: str = ""
|
||||||
total_input_tokens: int = 0
|
|
||||||
total_output_tokens: int = 0
|
|
||||||
total_tokens: int = 0
|
|
||||||
llm_call_count: int = 0
|
|
||||||
lead_agent_tokens: int = 0
|
|
||||||
subagent_tokens: int = 0
|
|
||||||
middleware_tokens: int = 0
|
|
||||||
message_count: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class ThreadTokenUsageModelBreakdown(BaseModel):
|
class ThreadTokenUsageModelBreakdown(BaseModel):
|
||||||
@@ -119,14 +111,6 @@ def _record_to_response(record: RunRecord) -> RunResponse:
|
|||||||
multitask_strategy=record.multitask_strategy,
|
multitask_strategy=record.multitask_strategy,
|
||||||
created_at=record.created_at,
|
created_at=record.created_at,
|
||||||
updated_at=record.updated_at,
|
updated_at=record.updated_at,
|
||||||
total_input_tokens=record.total_input_tokens,
|
|
||||||
total_output_tokens=record.total_output_tokens,
|
|
||||||
total_tokens=record.total_tokens,
|
|
||||||
llm_call_count=record.llm_call_count,
|
|
||||||
lead_agent_tokens=record.lead_agent_tokens,
|
|
||||||
subagent_tokens=record.subagent_tokens,
|
|
||||||
middleware_tokens=record.middleware_tokens,
|
|
||||||
message_count=record.message_count,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -418,15 +402,8 @@ async def list_run_events(
|
|||||||
|
|
||||||
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
|
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
|
||||||
@require_permission("threads", "read", owner_check=True)
|
@require_permission("threads", "read", owner_check=True)
|
||||||
async def thread_token_usage(
|
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
|
||||||
thread_id: str,
|
|
||||||
request: Request,
|
|
||||||
include_active: bool = Query(default=False, description="Include running run progress snapshots"),
|
|
||||||
) -> ThreadTokenUsageResponse:
|
|
||||||
"""Thread-level token usage aggregation."""
|
"""Thread-level token usage aggregation."""
|
||||||
run_store = get_run_store(request)
|
run_store = get_run_store(request)
|
||||||
if include_active:
|
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
||||||
agg = await run_store.aggregate_tokens_by_thread(thread_id, include_active=True)
|
|
||||||
else:
|
|
||||||
agg = await run_store.aggregate_tokens_by_thread(thread_id)
|
|
||||||
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
|
return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
|
||||||
|
|||||||
+7
-6
@@ -15,7 +15,6 @@ to the end of the message list as before_model + add_messages reducer would do.
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict, deque
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from typing import override
|
from typing import override
|
||||||
|
|
||||||
@@ -110,10 +109,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
This normalizes model-bound causal order before provider serialization while
|
This normalizes model-bound causal order before provider serialization while
|
||||||
preserving already-valid transcripts unchanged.
|
preserving already-valid transcripts unchanged.
|
||||||
"""
|
"""
|
||||||
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque)
|
tool_messages_by_id: dict[str, ToolMessage] = {}
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage):
|
if isinstance(msg, ToolMessage):
|
||||||
tool_messages_by_id[msg.tool_call_id].append(msg)
|
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
||||||
|
|
||||||
tool_call_ids: set[str] = set()
|
tool_call_ids: set[str] = set()
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
@@ -125,6 +124,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
tool_call_ids.add(tc_id)
|
tool_call_ids.add(tc_id)
|
||||||
|
|
||||||
patched: list = []
|
patched: list = []
|
||||||
|
consumed_tool_msg_ids: set[str] = set()
|
||||||
patch_count = 0
|
patch_count = 0
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
||||||
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
|
|
||||||
for tc in self._message_tool_calls(msg):
|
for tc in self._message_tool_calls(msg):
|
||||||
tc_id = tc.get("id")
|
tc_id = tc.get("id")
|
||||||
if not tc_id:
|
if not tc_id or tc_id in consumed_tool_msg_ids:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
tool_msg_queue = tool_messages_by_id.get(tc_id)
|
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
||||||
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
|
|
||||||
if existing_tool_msg is not None:
|
if existing_tool_msg is not None:
|
||||||
patched.append(existing_tool_msg)
|
patched.append(existing_tool_msg)
|
||||||
|
consumed_tool_msg_ids.add(tc_id)
|
||||||
else:
|
else:
|
||||||
patched.append(
|
patched.append(
|
||||||
ToolMessage(
|
ToolMessage(
|
||||||
@@ -152,6 +152,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
|||||||
status="error",
|
status="error",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
consumed_tool_msg_ids.add(tc_id)
|
||||||
patch_count += 1
|
patch_count += 1
|
||||||
|
|
||||||
if patched == messages:
|
if patched == messages:
|
||||||
|
|||||||
@@ -227,48 +227,9 @@ class RunRepository(RunStore):
|
|||||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def update_run_progress(
|
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||||
self,
|
|
||||||
run_id: str,
|
|
||||||
*,
|
|
||||||
total_input_tokens: int | None = None,
|
|
||||||
total_output_tokens: int | None = None,
|
|
||||||
total_tokens: int | None = None,
|
|
||||||
llm_call_count: int | None = None,
|
|
||||||
lead_agent_tokens: int | None = None,
|
|
||||||
subagent_tokens: int | None = None,
|
|
||||||
middleware_tokens: int | None = None,
|
|
||||||
message_count: int | None = None,
|
|
||||||
last_ai_message: str | None = None,
|
|
||||||
first_human_message: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Update token usage + convenience fields while a run is still active."""
|
|
||||||
values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
|
|
||||||
optional_counters = {
|
|
||||||
"total_input_tokens": total_input_tokens,
|
|
||||||
"total_output_tokens": total_output_tokens,
|
|
||||||
"total_tokens": total_tokens,
|
|
||||||
"llm_call_count": llm_call_count,
|
|
||||||
"lead_agent_tokens": lead_agent_tokens,
|
|
||||||
"subagent_tokens": subagent_tokens,
|
|
||||||
"middleware_tokens": middleware_tokens,
|
|
||||||
"message_count": message_count,
|
|
||||||
}
|
|
||||||
for key, value in optional_counters.items():
|
|
||||||
if value is not None:
|
|
||||||
values[key] = value
|
|
||||||
if last_ai_message is not None:
|
|
||||||
values["last_ai_message"] = last_ai_message[:2000]
|
|
||||||
if first_human_message is not None:
|
|
||||||
values["first_human_message"] = first_human_message[:2000]
|
|
||||||
async with self._sf() as session:
|
|
||||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values))
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
|
||||||
"""Aggregate token usage via a single SQL GROUP BY query."""
|
"""Aggregate token usage via a single SQL GROUP BY query."""
|
||||||
statuses = ("success", "error", "running") if include_active else ("success", "error")
|
_completed = RunRow.status.in_(("success", "error"))
|
||||||
_completed = RunRow.status.in_(statuses)
|
|
||||||
_thread = RunRow.thread_id == thread_id
|
_thread = RunRow.thread_id == thread_id
|
||||||
model_name = func.coalesce(RunRow.model_name, "unknown")
|
model_name = func.coalesce(RunRow.model_name, "unknown")
|
||||||
|
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable, Mapping
|
from collections.abc import Mapping
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
@@ -46,8 +46,6 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
*,
|
*,
|
||||||
track_token_usage: bool = True,
|
track_token_usage: bool = True,
|
||||||
flush_threshold: int = 20,
|
flush_threshold: int = 20,
|
||||||
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
|
|
||||||
progress_flush_interval: float = 5.0,
|
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.run_id = run_id
|
self.run_id = run_id
|
||||||
@@ -55,16 +53,10 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
self._store = event_store
|
self._store = event_store
|
||||||
self._track_tokens = track_token_usage
|
self._track_tokens = track_token_usage
|
||||||
self._flush_threshold = flush_threshold
|
self._flush_threshold = flush_threshold
|
||||||
self._progress_reporter = progress_reporter
|
|
||||||
self._progress_flush_interval = progress_flush_interval
|
|
||||||
|
|
||||||
# Write buffer
|
# Write buffer
|
||||||
self._buffer: list[dict] = []
|
self._buffer: list[dict] = []
|
||||||
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
self._pending_flush_tasks: set[asyncio.Task[None]] = set()
|
||||||
self._pending_progress_task: asyncio.Task[None] | None = None
|
|
||||||
self._pending_progress_delayed = False
|
|
||||||
self._progress_dirty = False
|
|
||||||
self._last_progress_flush = 0.0
|
|
||||||
|
|
||||||
# Token accumulators
|
# Token accumulators
|
||||||
self._total_input_tokens = 0
|
self._total_input_tokens = 0
|
||||||
@@ -302,8 +294,6 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
else:
|
else:
|
||||||
self._lead_agent_tokens += total_tk
|
self._lead_agent_tokens += total_tk
|
||||||
|
|
||||||
self._schedule_progress_flush()
|
|
||||||
|
|
||||||
if messages:
|
if messages:
|
||||||
self._counted_message_llm_run_ids.add(str(run_id))
|
self._counted_message_llm_run_ids.add(str(run_id))
|
||||||
|
|
||||||
@@ -455,8 +445,6 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
else:
|
else:
|
||||||
self._lead_agent_tokens += total_tk
|
self._lead_agent_tokens += total_tk
|
||||||
|
|
||||||
self._schedule_progress_flush()
|
|
||||||
|
|
||||||
def set_first_human_message(self, content: str) -> None:
|
def set_first_human_message(self, content: str) -> None:
|
||||||
"""Record the first human message for convenience fields."""
|
"""Record the first human message for convenience fields."""
|
||||||
self._first_human_msg = content[:2000] if content else None
|
self._first_human_msg = content[:2000] if content else None
|
||||||
@@ -486,14 +474,6 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
"""Force flush remaining buffer. Called in worker's finally block."""
|
"""Force flush remaining buffer. Called in worker's finally block."""
|
||||||
if self._pending_flush_tasks:
|
if self._pending_flush_tasks:
|
||||||
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
|
||||||
while self._pending_progress_task is not None and not self._pending_progress_task.done():
|
|
||||||
if self._pending_progress_delayed:
|
|
||||||
self._pending_progress_task.cancel()
|
|
||||||
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
|
|
||||||
self._progress_dirty = False
|
|
||||||
self._pending_progress_delayed = False
|
|
||||||
break
|
|
||||||
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
|
|
||||||
|
|
||||||
while self._buffer:
|
while self._buffer:
|
||||||
batch = self._buffer[: self._flush_threshold]
|
batch = self._buffer[: self._flush_threshold]
|
||||||
@@ -504,57 +484,6 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
self._buffer = batch + self._buffer
|
self._buffer = batch + self._buffer
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def _schedule_progress_flush(self) -> None:
|
|
||||||
"""Best-effort throttled progress snapshot for active run visibility."""
|
|
||||||
if self._progress_reporter is None:
|
|
||||||
return
|
|
||||||
now = time.monotonic()
|
|
||||||
elapsed = now - self._last_progress_flush
|
|
||||||
if elapsed < self._progress_flush_interval:
|
|
||||||
self._progress_dirty = True
|
|
||||||
self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed)
|
|
||||||
return
|
|
||||||
if self._pending_progress_task is not None and not self._pending_progress_task.done():
|
|
||||||
self._progress_dirty = True
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
return
|
|
||||||
self._progress_dirty = False
|
|
||||||
self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
|
|
||||||
|
|
||||||
def _schedule_delayed_progress_flush(self, delay: float) -> None:
|
|
||||||
if self._pending_progress_task is not None and not self._pending_progress_task.done():
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
loop = asyncio.get_running_loop()
|
|
||||||
except RuntimeError:
|
|
||||||
return
|
|
||||||
delay = max(0.0, delay)
|
|
||||||
self._pending_progress_delayed = delay > 0
|
|
||||||
self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay))
|
|
||||||
|
|
||||||
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
|
|
||||||
if self._progress_reporter is None:
|
|
||||||
return
|
|
||||||
if delay > 0:
|
|
||||||
self._pending_progress_delayed = True
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
self._pending_progress_delayed = False
|
|
||||||
dirty_before_write = self._progress_dirty
|
|
||||||
self._progress_dirty = False
|
|
||||||
snapshot_to_write = snapshot or self.get_completion_data()
|
|
||||||
try:
|
|
||||||
await self._progress_reporter(snapshot_to_write)
|
|
||||||
self._last_progress_flush = time.monotonic()
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
|
|
||||||
if dirty_before_write or self._progress_dirty:
|
|
||||||
self._progress_dirty = False
|
|
||||||
self._pending_progress_task = None
|
|
||||||
self._schedule_delayed_progress_flush(self._progress_flush_interval)
|
|
||||||
|
|
||||||
def get_completion_data(self) -> dict:
|
def get_completion_data(self) -> dict:
|
||||||
"""Return accumulated token and message data for run completion."""
|
"""Return accumulated token and message data for run completion."""
|
||||||
return {
|
return {
|
||||||
|
|||||||
@@ -38,16 +38,6 @@ class RunRecord:
|
|||||||
error: str | None = None
|
error: str | None = None
|
||||||
model_name: str | None = None
|
model_name: str | None = None
|
||||||
store_only: bool = False
|
store_only: bool = False
|
||||||
total_input_tokens: int = 0
|
|
||||||
total_output_tokens: int = 0
|
|
||||||
total_tokens: int = 0
|
|
||||||
llm_call_count: int = 0
|
|
||||||
lead_agent_tokens: int = 0
|
|
||||||
subagent_tokens: int = 0
|
|
||||||
middleware_tokens: int = 0
|
|
||||||
message_count: int = 0
|
|
||||||
last_ai_message: str | None = None
|
|
||||||
first_human_message: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class RunManager:
|
class RunManager:
|
||||||
@@ -63,24 +53,27 @@ class RunManager:
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._store = store
|
self._store = store
|
||||||
|
|
||||||
async def _persist_to_store(self, record: RunRecord) -> None:
|
async def _persist_new_run_to_store(self, record: RunRecord) -> None:
|
||||||
"""Best-effort persist run record to backing store."""
|
"""Persist a newly created run record to the backing store.
|
||||||
|
|
||||||
|
Initial run creation is part of the run visibility boundary: callers
|
||||||
|
should not observe a run in memory unless its backing store row exists.
|
||||||
|
Unlike follow-up status/model updates, failures are propagated so the
|
||||||
|
caller can treat creation as failed.
|
||||||
|
"""
|
||||||
if self._store is None:
|
if self._store is None:
|
||||||
return
|
return
|
||||||
try:
|
await self._store.put(
|
||||||
await self._store.put(
|
record.run_id,
|
||||||
record.run_id,
|
thread_id=record.thread_id,
|
||||||
thread_id=record.thread_id,
|
assistant_id=record.assistant_id,
|
||||||
assistant_id=record.assistant_id,
|
status=record.status.value,
|
||||||
status=record.status.value,
|
multitask_strategy=record.multitask_strategy,
|
||||||
multitask_strategy=record.multitask_strategy,
|
metadata=record.metadata or {},
|
||||||
metadata=record.metadata or {},
|
kwargs=record.kwargs or {},
|
||||||
kwargs=record.kwargs or {},
|
created_at=record.created_at,
|
||||||
created_at=record.created_at,
|
model_name=record.model_name,
|
||||||
model_name=record.model_name,
|
)
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
|
||||||
|
|
||||||
async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
|
||||||
"""Best-effort persist a status transition to the backing store."""
|
"""Best-effort persist a status transition to the backing store."""
|
||||||
@@ -112,53 +105,16 @@ class RunManager:
|
|||||||
error=row.get("error"),
|
error=row.get("error"),
|
||||||
model_name=row.get("model_name"),
|
model_name=row.get("model_name"),
|
||||||
store_only=True,
|
store_only=True,
|
||||||
total_input_tokens=row.get("total_input_tokens") or 0,
|
|
||||||
total_output_tokens=row.get("total_output_tokens") or 0,
|
|
||||||
total_tokens=row.get("total_tokens") or 0,
|
|
||||||
llm_call_count=row.get("llm_call_count") or 0,
|
|
||||||
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
|
|
||||||
subagent_tokens=row.get("subagent_tokens") or 0,
|
|
||||||
middleware_tokens=row.get("middleware_tokens") or 0,
|
|
||||||
message_count=row.get("message_count") or 0,
|
|
||||||
last_ai_message=row.get("last_ai_message"),
|
|
||||||
first_human_message=row.get("first_human_message"),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def update_run_completion(self, run_id: str, **kwargs) -> None:
|
async def update_run_completion(self, run_id: str, **kwargs) -> None:
|
||||||
"""Persist token usage and completion data to the backing store."""
|
"""Persist token usage and completion data to the backing store."""
|
||||||
async with self._lock:
|
|
||||||
record = self._runs.get(run_id)
|
|
||||||
if record is not None:
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
if key == "status":
|
|
||||||
continue
|
|
||||||
if hasattr(record, key) and value is not None:
|
|
||||||
setattr(record, key, value)
|
|
||||||
record.updated_at = _now_iso()
|
|
||||||
if self._store is not None:
|
if self._store is not None:
|
||||||
try:
|
try:
|
||||||
await self._store.update_run_completion(run_id, **kwargs)
|
await self._store.update_run_completion(run_id, **kwargs)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
||||||
|
|
||||||
async def update_run_progress(self, run_id: str, **kwargs) -> None:
|
|
||||||
"""Persist a running token/message snapshot without changing status."""
|
|
||||||
should_persist = True
|
|
||||||
async with self._lock:
|
|
||||||
record = self._runs.get(run_id)
|
|
||||||
if record is not None:
|
|
||||||
should_persist = record.status == RunStatus.running
|
|
||||||
if record is not None and should_persist:
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
if hasattr(record, key) and value is not None:
|
|
||||||
setattr(record, key, value)
|
|
||||||
record.updated_at = _now_iso()
|
|
||||||
if should_persist and self._store is not None:
|
|
||||||
try:
|
|
||||||
await self._store.update_run_progress(run_id, **kwargs)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)
|
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -186,7 +142,16 @@ class RunManager:
|
|||||||
)
|
)
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
await self._persist_to_store(record)
|
persisted = False
|
||||||
|
try:
|
||||||
|
await self._persist_new_run_to_store(record)
|
||||||
|
persisted = True
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if not persisted:
|
||||||
|
self._runs.pop(run_id, None)
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
@@ -364,16 +329,8 @@ class RunManager:
|
|||||||
raise ConflictError(f"Thread {thread_id} already has an active run")
|
raise ConflictError(f"Thread {thread_id} already has an active run")
|
||||||
|
|
||||||
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
||||||
for r in inflight:
|
|
||||||
r.abort_action = multitask_strategy
|
|
||||||
r.abort_event.set()
|
|
||||||
if r.task is not None and not r.task.done():
|
|
||||||
r.task.cancel()
|
|
||||||
r.status = RunStatus.interrupted
|
|
||||||
r.updated_at = now
|
|
||||||
interrupted_run_ids.append(r.run_id)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Cancelled %d inflight run(s) on thread %s (strategy=%s)",
|
"Preparing to cancel %d inflight run(s) on thread %s (strategy=%s)",
|
||||||
len(inflight),
|
len(inflight),
|
||||||
thread_id,
|
thread_id,
|
||||||
multitask_strategy,
|
multitask_strategy,
|
||||||
@@ -393,10 +350,29 @@ class RunManager:
|
|||||||
model_name=model_name,
|
model_name=model_name,
|
||||||
)
|
)
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
|
persisted = False
|
||||||
|
try:
|
||||||
|
await self._persist_new_run_to_store(record)
|
||||||
|
persisted = True
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
if not persisted:
|
||||||
|
self._runs.pop(run_id, None)
|
||||||
|
|
||||||
|
if multitask_strategy in ("interrupt", "rollback") and inflight:
|
||||||
|
for r in inflight:
|
||||||
|
r.abort_action = multitask_strategy
|
||||||
|
r.abort_event.set()
|
||||||
|
if r.task is not None and not r.task.done():
|
||||||
|
r.task.cancel()
|
||||||
|
r.status = RunStatus.interrupted
|
||||||
|
r.updated_at = now
|
||||||
|
interrupted_run_ids.append(r.run_id)
|
||||||
|
|
||||||
for interrupted_run_id in interrupted_run_ids:
|
for interrupted_run_id in interrupted_run_ids:
|
||||||
await self._persist_status(interrupted_run_id, RunStatus.interrupted)
|
await self._persist_status(interrupted_run_id, RunStatus.interrupted)
|
||||||
await self._persist_to_store(record)
|
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
|||||||
@@ -95,30 +95,12 @@ class RunStore(abc.ABC):
|
|||||||
) -> None:
|
) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def update_run_progress(
|
|
||||||
self,
|
|
||||||
run_id: str,
|
|
||||||
*,
|
|
||||||
total_input_tokens: int | None = None,
|
|
||||||
total_output_tokens: int | None = None,
|
|
||||||
total_tokens: int | None = None,
|
|
||||||
llm_call_count: int | None = None,
|
|
||||||
lead_agent_tokens: int | None = None,
|
|
||||||
subagent_tokens: int | None = None,
|
|
||||||
middleware_tokens: int | None = None,
|
|
||||||
message_count: int | None = None,
|
|
||||||
last_ai_message: str | None = None,
|
|
||||||
first_human_message: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Persist a best-effort running snapshot without changing run status."""
|
|
||||||
return None
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
|
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||||
"""Aggregate token usage for completed runs in a thread.
|
"""Aggregate token usage for completed runs in a thread.
|
||||||
|
|
||||||
Returns a dict with keys: total_tokens, total_input_tokens,
|
Returns a dict with keys: total_tokens, total_input_tokens,
|
||||||
|
|||||||
@@ -82,22 +82,14 @@ class MemoryRunStore(RunStore):
|
|||||||
self._runs[run_id][key] = value
|
self._runs[run_id][key] = value
|
||||||
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
||||||
|
|
||||||
async def update_run_progress(self, run_id, **kwargs):
|
|
||||||
if run_id in self._runs and self._runs[run_id].get("status") == "running":
|
|
||||||
for key, value in kwargs.items():
|
|
||||||
if value is not None:
|
|
||||||
self._runs[run_id][key] = value
|
|
||||||
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
|
||||||
|
|
||||||
async def list_pending(self, *, before=None):
|
async def list_pending(self, *, before=None):
|
||||||
now = before or datetime.now(UTC).isoformat()
|
now = before or datetime.now(UTC).isoformat()
|
||||||
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
|
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
|
||||||
results.sort(key=lambda r: r["created_at"])
|
results.sort(key=lambda r: r["created_at"])
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||||
statuses = ("success", "error", "running") if include_active else ("success", "error")
|
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
|
||||||
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
|
|
||||||
by_model: dict[str, dict] = {}
|
by_model: dict[str, dict] = {}
|
||||||
for r in completed:
|
for r in completed:
|
||||||
model = r.get("model_name") or "unknown"
|
model = r.get("model_name") or "unknown"
|
||||||
|
|||||||
@@ -153,6 +153,8 @@ async def run_agent(
|
|||||||
|
|
||||||
journal = None
|
journal = None
|
||||||
|
|
||||||
|
journal = None
|
||||||
|
|
||||||
# Track whether "events" was requested but skipped
|
# Track whether "events" was requested but skipped
|
||||||
if "events" in requested_modes:
|
if "events" in requested_modes:
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -175,7 +177,6 @@ async def run_agent(
|
|||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
event_store=event_store,
|
event_store=event_store,
|
||||||
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
track_token_usage=getattr(run_events_config, "track_token_usage", True),
|
||||||
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 1. Mark running
|
# 1. Mark running
|
||||||
|
|||||||
@@ -218,70 +218,6 @@ class TestBuildPatchedMessagesPatching:
|
|||||||
|
|
||||||
assert mw._build_patched_messages(msgs) is None
|
assert mw._build_patched_messages(msgs) is None
|
||||||
|
|
||||||
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
|
|
||||||
mw = DanglingToolCallMiddleware()
|
|
||||||
msgs = [
|
|
||||||
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
|
|
||||||
_ai_with_tool_calls(
|
|
||||||
[
|
|
||||||
_tc("web_search", "web_search:11"),
|
|
||||||
_tc("web_search", "web_search:12"),
|
|
||||||
_tc("web_search", "web_search:13"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
_tool_msg("web_search:11", "web_search"),
|
|
||||||
_tool_msg("web_search:12", "web_search"),
|
|
||||||
_tool_msg("web_search:13", "web_search"),
|
|
||||||
_ai_with_tool_calls(
|
|
||||||
[
|
|
||||||
_tc("web_search", "web_search:9"),
|
|
||||||
_tc("web_search", "web_search:10"),
|
|
||||||
_tc("web_search", "web_search:11"),
|
|
||||||
]
|
|
||||||
),
|
|
||||||
_tool_msg("web_search:9", "web_search"),
|
|
||||||
_tool_msg("web_search:10", "web_search"),
|
|
||||||
_tool_msg("web_search:11", "web_search"),
|
|
||||||
]
|
|
||||||
|
|
||||||
assert mw._build_patched_messages(msgs) is None
|
|
||||||
|
|
||||||
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
|
|
||||||
mw = DanglingToolCallMiddleware()
|
|
||||||
msgs = [
|
|
||||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
|
||||||
_tool_msg("web_search:11", "web_search"),
|
|
||||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
|
||||||
]
|
|
||||||
|
|
||||||
patched = mw._build_patched_messages(msgs)
|
|
||||||
|
|
||||||
assert patched is not None
|
|
||||||
assert isinstance(patched[1], ToolMessage)
|
|
||||||
assert patched[1].tool_call_id == "web_search:11"
|
|
||||||
assert patched[1].status == "success"
|
|
||||||
assert isinstance(patched[3], ToolMessage)
|
|
||||||
assert patched[3].tool_call_id == "web_search:11"
|
|
||||||
assert patched[3].status == "error"
|
|
||||||
|
|
||||||
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
|
|
||||||
mw = DanglingToolCallMiddleware()
|
|
||||||
result = _tool_msg("web_search:11", "web_search")
|
|
||||||
msgs = [
|
|
||||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
|
||||||
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
|
|
||||||
result,
|
|
||||||
]
|
|
||||||
|
|
||||||
patched = mw._build_patched_messages(msgs)
|
|
||||||
|
|
||||||
assert patched is not None
|
|
||||||
assert patched[1] is result
|
|
||||||
assert patched[1].status == "success"
|
|
||||||
assert isinstance(patched[3], ToolMessage)
|
|
||||||
assert patched[3].tool_call_id == "web_search:11"
|
|
||||||
assert patched[3].status == "error"
|
|
||||||
|
|
||||||
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
|
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
|
||||||
mw = DanglingToolCallMiddleware()
|
mw = DanglingToolCallMiddleware()
|
||||||
msgs = [
|
msgs = [
|
||||||
|
|||||||
@@ -714,110 +714,6 @@ class TestExternalUsageRecords:
|
|||||||
assert j._subagent_tokens == 0
|
assert j._subagent_tokens == 0
|
||||||
|
|
||||||
|
|
||||||
class TestProgressSnapshots:
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_on_llm_end_reports_progress_snapshot(self):
|
|
||||||
snapshots: list[dict] = []
|
|
||||||
|
|
||||||
async def reporter(snapshot: dict) -> None:
|
|
||||||
snapshots.append(snapshot)
|
|
||||||
|
|
||||||
store = MemoryRunEventStore()
|
|
||||||
j = RunJournal(
|
|
||||||
"r1",
|
|
||||||
"t1",
|
|
||||||
store,
|
|
||||||
flush_threshold=100,
|
|
||||||
progress_reporter=reporter,
|
|
||||||
progress_flush_interval=0,
|
|
||||||
)
|
|
||||||
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
|
|
||||||
j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
|
|
||||||
await j.flush()
|
|
||||||
|
|
||||||
assert snapshots
|
|
||||||
assert snapshots[-1]["total_tokens"] == 15
|
|
||||||
assert snapshots[-1]["llm_call_count"] == 1
|
|
||||||
assert snapshots[-1]["message_count"] == 1
|
|
||||||
assert snapshots[-1]["last_ai_message"] == "Answer"
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_throttled_progress_flush_emits_trailing_snapshot(self):
|
|
||||||
snapshots: list[dict] = []
|
|
||||||
trailing_seen = asyncio.Event()
|
|
||||||
|
|
||||||
async def reporter(snapshot: dict) -> None:
|
|
||||||
snapshots.append(snapshot)
|
|
||||||
if snapshot["total_tokens"] == 45:
|
|
||||||
trailing_seen.set()
|
|
||||||
|
|
||||||
store = MemoryRunEventStore()
|
|
||||||
j = RunJournal(
|
|
||||||
"r1",
|
|
||||||
"t1",
|
|
||||||
store,
|
|
||||||
flush_threshold=100,
|
|
||||||
progress_reporter=reporter,
|
|
||||||
progress_flush_interval=0.01,
|
|
||||||
)
|
|
||||||
j.on_llm_end(
|
|
||||||
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
|
|
||||||
run_id=uuid4(),
|
|
||||||
parent_run_id=None,
|
|
||||||
tags=["lead_agent"],
|
|
||||||
)
|
|
||||||
j.on_llm_end(
|
|
||||||
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
|
|
||||||
run_id=uuid4(),
|
|
||||||
parent_run_id=None,
|
|
||||||
tags=["lead_agent"],
|
|
||||||
)
|
|
||||||
await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
|
|
||||||
await j.flush()
|
|
||||||
|
|
||||||
assert len(snapshots) >= 2
|
|
||||||
assert snapshots[-1]["total_tokens"] == 45
|
|
||||||
assert snapshots[-1]["llm_call_count"] == 2
|
|
||||||
assert snapshots[-1]["last_ai_message"] == "Second"
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
|
|
||||||
snapshots: list[dict] = []
|
|
||||||
|
|
||||||
async def reporter(snapshot: dict) -> None:
|
|
||||||
snapshots.append(snapshot)
|
|
||||||
|
|
||||||
store = MemoryRunEventStore()
|
|
||||||
j = RunJournal(
|
|
||||||
"r1",
|
|
||||||
"t1",
|
|
||||||
store,
|
|
||||||
flush_threshold=100,
|
|
||||||
progress_reporter=reporter,
|
|
||||||
progress_flush_interval=10.0,
|
|
||||||
)
|
|
||||||
j.on_llm_end(
|
|
||||||
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
|
|
||||||
run_id=uuid4(),
|
|
||||||
parent_run_id=None,
|
|
||||||
tags=["lead_agent"],
|
|
||||||
)
|
|
||||||
await asyncio.sleep(0)
|
|
||||||
assert snapshots[-1]["total_tokens"] == 15
|
|
||||||
j.on_llm_end(
|
|
||||||
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
|
|
||||||
run_id=uuid4(),
|
|
||||||
parent_run_id=None,
|
|
||||||
tags=["lead_agent"],
|
|
||||||
)
|
|
||||||
|
|
||||||
await asyncio.wait_for(j.flush(), timeout=0.2)
|
|
||||||
|
|
||||||
assert snapshots[-1]["total_tokens"] == 15
|
|
||||||
assert snapshots[-1]["llm_call_count"] == 1
|
|
||||||
assert snapshots[-1]["last_ai_message"] == "First"
|
|
||||||
|
|
||||||
|
|
||||||
class TestChatModelStartHumanMessage:
|
class TestChatModelStartHumanMessage:
|
||||||
"""Tests for on_chat_model_start extracting the first human message."""
|
"""Tests for on_chat_model_start extracting the first human message."""
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""Tests for RunManager."""
|
"""Tests for RunManager."""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import re
|
import re
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -231,6 +232,81 @@ async def test_create_record_is_not_store_only(manager: RunManager):
|
|||||||
assert record.store_only is False
|
assert record.store_only is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_rolls_back_in_memory_record_on_store_failure():
|
||||||
|
"""create() must fail and hide the run when the initial store write fails."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
store = MemoryRunStore()
|
||||||
|
store.put = AsyncMock(side_effect=RuntimeError("db down"))
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="db down"):
|
||||||
|
await manager.create("thread-1")
|
||||||
|
|
||||||
|
assert manager._runs == {}
|
||||||
|
assert await manager.list_by_thread("thread-1") == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_rolls_back_in_memory_record_on_store_cancellation():
|
||||||
|
"""create() must also roll back when cancelled during the initial store write."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
|
||||||
|
async def cancelled_put(run_id, **kwargs):
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
store.put = cancelled_put
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await manager.create("thread-1")
|
||||||
|
|
||||||
|
assert manager._runs == {}
|
||||||
|
assert await manager.list_by_thread("thread-1") == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_does_not_expose_run_until_store_persist_completes():
|
||||||
|
"""Concurrent readers must wait until the new run has been persisted."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
original_put = store.put
|
||||||
|
put_started = asyncio.Event()
|
||||||
|
allow_put = asyncio.Event()
|
||||||
|
|
||||||
|
async def blocking_put(run_id, **kwargs):
|
||||||
|
put_started.set()
|
||||||
|
await allow_put.wait()
|
||||||
|
return await original_put(run_id, **kwargs)
|
||||||
|
|
||||||
|
store.put = blocking_put
|
||||||
|
create_task = asyncio.create_task(manager.create("thread-1"))
|
||||||
|
list_task = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
await put_started.wait()
|
||||||
|
list_task = asyncio.create_task(manager.list_by_thread("thread-1"))
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
assert not list_task.done()
|
||||||
|
|
||||||
|
allow_put.set()
|
||||||
|
record = await create_task
|
||||||
|
runs = await list_task
|
||||||
|
|
||||||
|
assert [run.run_id for run in runs] == [record.run_id]
|
||||||
|
finally:
|
||||||
|
allow_put.set()
|
||||||
|
cleanup_tasks = []
|
||||||
|
for task in (list_task, create_task):
|
||||||
|
if task is None:
|
||||||
|
continue
|
||||||
|
if not task.done():
|
||||||
|
task.cancel()
|
||||||
|
cleanup_tasks.append(task)
|
||||||
|
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_get_prefers_in_memory_record_over_store():
|
async def test_get_prefers_in_memory_record_over_store():
|
||||||
"""In-memory records retain task/control state when store has same run."""
|
"""In-memory records retain task/control state when store has same run."""
|
||||||
@@ -318,6 +394,52 @@ async def test_create_or_reject_interrupt_persists_interrupted_status_to_store()
|
|||||||
assert stored_old["status"] == "interrupted"
|
assert stored_old["status"] == "interrupted"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_fails():
|
||||||
|
"""A failed new-run persist must not cancel the existing inflight run."""
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
store = MemoryRunStore()
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
old = await manager.create("thread-1")
|
||||||
|
await manager.set_status(old.run_id, RunStatus.running)
|
||||||
|
store.put = AsyncMock(side_effect=RuntimeError("db down"))
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="db down"):
|
||||||
|
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
|
||||||
|
|
||||||
|
stored_old = await store.get(old.run_id)
|
||||||
|
assert list(manager._runs) == [old.run_id]
|
||||||
|
assert old.status == RunStatus.running
|
||||||
|
assert old.abort_event.is_set() is False
|
||||||
|
assert stored_old is not None
|
||||||
|
assert stored_old["status"] == "running"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_is_cancelled():
|
||||||
|
"""Cancellation during new-run persist must not cancel the existing run."""
|
||||||
|
store = MemoryRunStore()
|
||||||
|
manager = RunManager(store=store)
|
||||||
|
old = await manager.create("thread-1")
|
||||||
|
await manager.set_status(old.run_id, RunStatus.running)
|
||||||
|
|
||||||
|
async def cancelled_put(run_id, **kwargs):
|
||||||
|
raise asyncio.CancelledError
|
||||||
|
|
||||||
|
store.put = cancelled_put
|
||||||
|
|
||||||
|
with pytest.raises(asyncio.CancelledError):
|
||||||
|
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
|
||||||
|
|
||||||
|
stored_old = await store.get(old.run_id)
|
||||||
|
assert list(manager._runs) == [old.run_id]
|
||||||
|
assert old.status == RunStatus.running
|
||||||
|
assert old.abort_event.is_set() is False
|
||||||
|
assert stored_old is not None
|
||||||
|
assert stored_old["status"] == "running"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
|
async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
|
||||||
"""rollback strategy should persist interrupted status for old runs."""
|
"""rollback strategy should persist interrupted status for old runs."""
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ from sqlalchemy.dialects import postgresql
|
|||||||
|
|
||||||
from deerflow.persistence.run import RunRepository
|
from deerflow.persistence.run import RunRepository
|
||||||
from deerflow.runtime import RunManager, RunStatus
|
from deerflow.runtime import RunManager, RunStatus
|
||||||
from deerflow.runtime.runs.store.base import RunStore
|
|
||||||
|
|
||||||
|
|
||||||
async def _make_repo(tmp_path):
|
async def _make_repo(tmp_path):
|
||||||
@@ -27,42 +26,6 @@ async def _cleanup():
|
|||||||
await close_engine()
|
await close_engine()
|
||||||
|
|
||||||
|
|
||||||
class _CustomRunStoreWithoutProgress(RunStore):
|
|
||||||
async def put(self, *args, **kwargs):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get(self, *args, **kwargs):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def list_by_thread(self, *args, **kwargs):
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def update_status(self, *args, **kwargs):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def delete(self, *args, **kwargs):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def update_model_name(self, *args, **kwargs):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def update_run_completion(self, *args, **kwargs):
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def list_pending(self, *args, **kwargs):
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def aggregate_tokens_by_thread(self, *args, **kwargs):
|
|
||||||
return {}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_update_run_progress_defaults_to_noop_for_custom_store():
|
|
||||||
store = _CustomRunStoreWithoutProgress()
|
|
||||||
|
|
||||||
await store.update_run_progress("r1", total_tokens=1)
|
|
||||||
|
|
||||||
|
|
||||||
class TestRunRepository:
|
class TestRunRepository:
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_put_and_get(self, tmp_path):
|
async def test_put_and_get(self, tmp_path):
|
||||||
@@ -207,69 +170,6 @@ class TestRunRepository:
|
|||||||
assert row["total_tokens"] == 100
|
assert row["total_tokens"] == 100
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_update_run_progress_keeps_status_running(self, tmp_path):
|
|
||||||
repo = await _make_repo(tmp_path)
|
|
||||||
await repo.put("r1", thread_id="t1", status="running")
|
|
||||||
await repo.update_run_progress(
|
|
||||||
"r1",
|
|
||||||
total_input_tokens=40,
|
|
||||||
total_output_tokens=10,
|
|
||||||
total_tokens=50,
|
|
||||||
llm_call_count=1,
|
|
||||||
message_count=2,
|
|
||||||
last_ai_message="partial answer",
|
|
||||||
)
|
|
||||||
row = await repo.get("r1")
|
|
||||||
assert row["status"] == "running"
|
|
||||||
assert row["total_tokens"] == 50
|
|
||||||
assert row["llm_call_count"] == 1
|
|
||||||
assert row["message_count"] == 2
|
|
||||||
assert row["last_ai_message"] == "partial answer"
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
|
|
||||||
repo = await _make_repo(tmp_path)
|
|
||||||
await repo.put("r1", thread_id="t1", status="running")
|
|
||||||
await repo.update_run_progress(
|
|
||||||
"r1",
|
|
||||||
total_input_tokens=40,
|
|
||||||
total_output_tokens=10,
|
|
||||||
total_tokens=50,
|
|
||||||
llm_call_count=1,
|
|
||||||
lead_agent_tokens=30,
|
|
||||||
subagent_tokens=20,
|
|
||||||
message_count=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")
|
|
||||||
|
|
||||||
row = await repo.get("r1")
|
|
||||||
assert row["total_input_tokens"] == 40
|
|
||||||
assert row["total_output_tokens"] == 10
|
|
||||||
assert row["total_tokens"] == 60
|
|
||||||
assert row["llm_call_count"] == 1
|
|
||||||
assert row["lead_agent_tokens"] == 30
|
|
||||||
assert row["subagent_tokens"] == 20
|
|
||||||
assert row["message_count"] == 2
|
|
||||||
assert row["last_ai_message"] == "updated"
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
|
|
||||||
repo = await _make_repo(tmp_path)
|
|
||||||
await repo.put("r1", thread_id="t1", status="running")
|
|
||||||
await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)
|
|
||||||
|
|
||||||
await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)
|
|
||||||
|
|
||||||
row = await repo.get("r1")
|
|
||||||
assert row["status"] == "success"
|
|
||||||
assert row["total_tokens"] == 100
|
|
||||||
assert row["llm_call_count"] == 1
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
|
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
|
||||||
repo = await _make_repo(tmp_path)
|
repo = await _make_repo(tmp_path)
|
||||||
@@ -325,28 +225,6 @@ class TestRunRepository:
|
|||||||
}
|
}
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
|
||||||
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
|
|
||||||
repo = await _make_repo(tmp_path)
|
|
||||||
await repo.put("success-run", thread_id="t1", status="running")
|
|
||||||
await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
|
|
||||||
await repo.put("running-run", thread_id="t1", status="running")
|
|
||||||
await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)
|
|
||||||
|
|
||||||
without_active = await repo.aggregate_tokens_by_thread("t1")
|
|
||||||
with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)
|
|
||||||
|
|
||||||
assert without_active["total_tokens"] == 100
|
|
||||||
assert without_active["total_runs"] == 1
|
|
||||||
assert with_active["total_tokens"] == 125
|
|
||||||
assert with_active["total_runs"] == 2
|
|
||||||
assert with_active["by_caller"] == {
|
|
||||||
"lead_agent": 120,
|
|
||||||
"subagent": 5,
|
|
||||||
"middleware": 0,
|
|
||||||
}
|
|
||||||
await _cleanup()
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_list_by_thread_ordered_desc(self, tmp_path):
|
async def test_list_by_thread_ordered_desc(self, tmp_path):
|
||||||
"""list_by_thread returns newest first."""
|
"""list_by_thread returns newest first."""
|
||||||
|
|||||||
@@ -53,30 +53,3 @@ def test_thread_token_usage_returns_stable_shape():
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
|
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
|
||||||
|
|
||||||
|
|
||||||
def test_thread_token_usage_can_include_active_runs():
|
|
||||||
run_store = MagicMock()
|
|
||||||
run_store.aggregate_tokens_by_thread = AsyncMock(
|
|
||||||
return_value={
|
|
||||||
"total_tokens": 175,
|
|
||||||
"total_input_tokens": 120,
|
|
||||||
"total_output_tokens": 55,
|
|
||||||
"total_runs": 3,
|
|
||||||
"by_model": {"unknown": {"tokens": 175, "runs": 3}},
|
|
||||||
"by_caller": {
|
|
||||||
"lead_agent": 145,
|
|
||||||
"subagent": 25,
|
|
||||||
"middleware": 5,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
app = _make_app(run_store)
|
|
||||||
|
|
||||||
with TestClient(app) as client:
|
|
||||||
response = client.get("/api/threads/thread-1/token-usage?include_active=true")
|
|
||||||
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.json()["total_tokens"] == 175
|
|
||||||
assert response.json()["total_runs"] == 3
|
|
||||||
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user