Compare commits

..

5 Commits

Author SHA1 Message Date
Willem Jiang 11a362e5e5 Merge branch 'main' into rayhpeng/fix-run-manager-store-atomicity 2026-05-22 21:38:23 +08:00
rayhpeng 85402405ec clarify run creation rollback on cancellation 2026-05-22 18:10:53 +08:00
rayhpeng 43eb643910 fix run manager test cleanup await 2026-05-22 18:01:55 +08:00
rayhpeng f3e3a350ce fix run creation cancellation rollback 2026-05-22 17:58:28 +08:00
rayhpeng 0fae7c9cbb fix runtime run creation persistence atomicity 2026-05-22 11:10:06 +08:00
13 changed files with 189 additions and 565 deletions
+2 -25
View File
@@ -66,14 +66,6 @@ class RunResponse(BaseModel):
multitask_strategy: str = "reject" multitask_strategy: str = "reject"
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
class ThreadTokenUsageModelBreakdown(BaseModel): class ThreadTokenUsageModelBreakdown(BaseModel):
@@ -119,14 +111,6 @@ def _record_to_response(record: RunRecord) -> RunResponse:
multitask_strategy=record.multitask_strategy, multitask_strategy=record.multitask_strategy,
created_at=record.created_at, created_at=record.created_at,
updated_at=record.updated_at, updated_at=record.updated_at,
total_input_tokens=record.total_input_tokens,
total_output_tokens=record.total_output_tokens,
total_tokens=record.total_tokens,
llm_call_count=record.llm_call_count,
lead_agent_tokens=record.lead_agent_tokens,
subagent_tokens=record.subagent_tokens,
middleware_tokens=record.middleware_tokens,
message_count=record.message_count,
) )
@@ -418,15 +402,8 @@ async def list_run_events(
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse) @router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
@require_permission("threads", "read", owner_check=True) @require_permission("threads", "read", owner_check=True)
async def thread_token_usage( async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse:
thread_id: str,
request: Request,
include_active: bool = Query(default=False, description="Include running run progress snapshots"),
) -> ThreadTokenUsageResponse:
"""Thread-level token usage aggregation.""" """Thread-level token usage aggregation."""
run_store = get_run_store(request) run_store = get_run_store(request)
if include_active: agg = await run_store.aggregate_tokens_by_thread(thread_id)
agg = await run_store.aggregate_tokens_by_thread(thread_id, include_active=True)
else:
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return ThreadTokenUsageResponse(thread_id=thread_id, **agg) return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
@@ -15,7 +15,6 @@ to the end of the message list as before_model + add_messages reducer would do.
import json import json
import logging import logging
from collections import defaultdict, deque
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import override from typing import override
@@ -110,10 +109,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
This normalizes model-bound causal order before provider serialization while This normalizes model-bound causal order before provider serialization while
preserving already-valid transcripts unchanged. preserving already-valid transcripts unchanged.
""" """
tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque) tool_messages_by_id: dict[str, ToolMessage] = {}
for msg in messages: for msg in messages:
if isinstance(msg, ToolMessage): if isinstance(msg, ToolMessage):
tool_messages_by_id[msg.tool_call_id].append(msg) tool_messages_by_id.setdefault(msg.tool_call_id, msg)
tool_call_ids: set[str] = set() tool_call_ids: set[str] = set()
for msg in messages: for msg in messages:
@@ -125,6 +124,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
tool_call_ids.add(tc_id) tool_call_ids.add(tc_id)
patched: list = [] patched: list = []
consumed_tool_msg_ids: set[str] = set()
patch_count = 0 patch_count = 0
for msg in messages: for msg in messages:
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids: if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
for tc in self._message_tool_calls(msg): for tc in self._message_tool_calls(msg):
tc_id = tc.get("id") tc_id = tc.get("id")
if not tc_id: if not tc_id or tc_id in consumed_tool_msg_ids:
continue continue
tool_msg_queue = tool_messages_by_id.get(tc_id) existing_tool_msg = tool_messages_by_id.get(tc_id)
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
if existing_tool_msg is not None: if existing_tool_msg is not None:
patched.append(existing_tool_msg) patched.append(existing_tool_msg)
consumed_tool_msg_ids.add(tc_id)
else: else:
patched.append( patched.append(
ToolMessage( ToolMessage(
@@ -152,6 +152,7 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error", status="error",
) )
) )
consumed_tool_msg_ids.add(tc_id)
patch_count += 1 patch_count += 1
if patched == messages: if patched == messages:
@@ -227,48 +227,9 @@ class RunRepository(RunStore):
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit() await session.commit()
async def update_run_progress( async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Update token usage + convenience fields while a run is still active."""
values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
optional_counters = {
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
}
for key, value in optional_counters.items():
if value is not None:
values[key] = value
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage via a single SQL GROUP BY query.""" """Aggregate token usage via a single SQL GROUP BY query."""
statuses = ("success", "error", "running") if include_active else ("success", "error") _completed = RunRow.status.in_(("success", "error"))
_completed = RunRow.status.in_(statuses)
_thread = RunRow.thread_id == thread_id _thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown") model_name = func.coalesce(RunRow.model_name, "unknown")
@@ -20,7 +20,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time import time
from collections.abc import Awaitable, Callable, Mapping from collections.abc import Mapping
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from uuid import UUID from uuid import UUID
@@ -46,8 +46,6 @@ class RunJournal(BaseCallbackHandler):
*, *,
track_token_usage: bool = True, track_token_usage: bool = True,
flush_threshold: int = 20, flush_threshold: int = 20,
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
progress_flush_interval: float = 5.0,
): ):
super().__init__() super().__init__()
self.run_id = run_id self.run_id = run_id
@@ -55,16 +53,10 @@ class RunJournal(BaseCallbackHandler):
self._store = event_store self._store = event_store
self._track_tokens = track_token_usage self._track_tokens = track_token_usage
self._flush_threshold = flush_threshold self._flush_threshold = flush_threshold
self._progress_reporter = progress_reporter
self._progress_flush_interval = progress_flush_interval
# Write buffer # Write buffer
self._buffer: list[dict] = [] self._buffer: list[dict] = []
self._pending_flush_tasks: set[asyncio.Task[None]] = set() self._pending_flush_tasks: set[asyncio.Task[None]] = set()
self._pending_progress_task: asyncio.Task[None] | None = None
self._pending_progress_delayed = False
self._progress_dirty = False
self._last_progress_flush = 0.0
# Token accumulators # Token accumulators
self._total_input_tokens = 0 self._total_input_tokens = 0
@@ -302,8 +294,6 @@ class RunJournal(BaseCallbackHandler):
else: else:
self._lead_agent_tokens += total_tk self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
if messages: if messages:
self._counted_message_llm_run_ids.add(str(run_id)) self._counted_message_llm_run_ids.add(str(run_id))
@@ -455,8 +445,6 @@ class RunJournal(BaseCallbackHandler):
else: else:
self._lead_agent_tokens += total_tk self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
def set_first_human_message(self, content: str) -> None: def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields.""" """Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None self._first_human_msg = content[:2000] if content else None
@@ -486,14 +474,6 @@ class RunJournal(BaseCallbackHandler):
"""Force flush remaining buffer. Called in worker's finally block.""" """Force flush remaining buffer. Called in worker's finally block."""
if self._pending_flush_tasks: if self._pending_flush_tasks:
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True) await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
while self._pending_progress_task is not None and not self._pending_progress_task.done():
if self._pending_progress_delayed:
self._pending_progress_task.cancel()
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
self._progress_dirty = False
self._pending_progress_delayed = False
break
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
while self._buffer: while self._buffer:
batch = self._buffer[: self._flush_threshold] batch = self._buffer[: self._flush_threshold]
@@ -504,57 +484,6 @@ class RunJournal(BaseCallbackHandler):
self._buffer = batch + self._buffer self._buffer = batch + self._buffer
raise raise
def _schedule_progress_flush(self) -> None:
"""Best-effort throttled progress snapshot for active run visibility."""
if self._progress_reporter is None:
return
now = time.monotonic()
elapsed = now - self._last_progress_flush
if elapsed < self._progress_flush_interval:
self._progress_dirty = True
self._schedule_delayed_progress_flush(self._progress_flush_interval - elapsed)
return
if self._pending_progress_task is not None and not self._pending_progress_task.done():
self._progress_dirty = True
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
self._progress_dirty = False
self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
def _schedule_delayed_progress_flush(self, delay: float) -> None:
if self._pending_progress_task is not None and not self._pending_progress_task.done():
return
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return
delay = max(0.0, delay)
self._pending_progress_delayed = delay > 0
self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=delay))
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
if self._progress_reporter is None:
return
if delay > 0:
self._pending_progress_delayed = True
await asyncio.sleep(delay)
self._pending_progress_delayed = False
dirty_before_write = self._progress_dirty
self._progress_dirty = False
snapshot_to_write = snapshot or self.get_completion_data()
try:
await self._progress_reporter(snapshot_to_write)
self._last_progress_flush = time.monotonic()
except Exception:
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
if dirty_before_write or self._progress_dirty:
self._progress_dirty = False
self._pending_progress_task = None
self._schedule_delayed_progress_flush(self._progress_flush_interval)
def get_completion_data(self) -> dict: def get_completion_data(self) -> dict:
"""Return accumulated token and message data for run completion.""" """Return accumulated token and message data for run completion."""
return { return {
@@ -38,16 +38,6 @@ class RunRecord:
error: str | None = None error: str | None = None
model_name: str | None = None model_name: str | None = None
store_only: bool = False store_only: bool = False
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
last_ai_message: str | None = None
first_human_message: str | None = None
class RunManager: class RunManager:
@@ -63,24 +53,27 @@ class RunManager:
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
self._store = store self._store = store
async def _persist_to_store(self, record: RunRecord) -> None: async def _persist_new_run_to_store(self, record: RunRecord) -> None:
"""Best-effort persist run record to backing store.""" """Persist a newly created run record to the backing store.
Initial run creation is part of the run visibility boundary: callers
should not observe a run in memory unless its backing store row exists.
Unlike follow-up status/model updates, failures are propagated so the
caller can treat creation as failed.
"""
if self._store is None: if self._store is None:
return return
try: await self._store.put(
await self._store.put( record.run_id,
record.run_id, thread_id=record.thread_id,
thread_id=record.thread_id, assistant_id=record.assistant_id,
assistant_id=record.assistant_id, status=record.status.value,
status=record.status.value, multitask_strategy=record.multitask_strategy,
multitask_strategy=record.multitask_strategy, metadata=record.metadata or {},
metadata=record.metadata or {}, kwargs=record.kwargs or {},
kwargs=record.kwargs or {}, created_at=record.created_at,
created_at=record.created_at, model_name=record.model_name,
model_name=record.model_name, )
)
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None: async def _persist_status(self, run_id: str, status: RunStatus, *, error: str | None = None) -> None:
"""Best-effort persist a status transition to the backing store.""" """Best-effort persist a status transition to the backing store."""
@@ -112,53 +105,16 @@ class RunManager:
error=row.get("error"), error=row.get("error"),
model_name=row.get("model_name"), model_name=row.get("model_name"),
store_only=True, store_only=True,
total_input_tokens=row.get("total_input_tokens") or 0,
total_output_tokens=row.get("total_output_tokens") or 0,
total_tokens=row.get("total_tokens") or 0,
llm_call_count=row.get("llm_call_count") or 0,
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
subagent_tokens=row.get("subagent_tokens") or 0,
middleware_tokens=row.get("middleware_tokens") or 0,
message_count=row.get("message_count") or 0,
last_ai_message=row.get("last_ai_message"),
first_human_message=row.get("first_human_message"),
) )
async def update_run_completion(self, run_id: str, **kwargs) -> None: async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store.""" """Persist token usage and completion data to the backing store."""
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
for key, value in kwargs.items():
if key == "status":
continue
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if self._store is not None: if self._store is not None:
try: try:
await self._store.update_run_completion(run_id, **kwargs) await self._store.update_run_completion(run_id, **kwargs)
except Exception: except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def update_run_progress(self, run_id: str, **kwargs) -> None:
"""Persist a running token/message snapshot without changing status."""
should_persist = True
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
should_persist = record.status == RunStatus.running
if record is not None and should_persist:
for key, value in kwargs.items():
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if should_persist and self._store is not None:
try:
await self._store.update_run_progress(run_id, **kwargs)
except Exception:
logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)
async def create( async def create(
self, self,
thread_id: str, thread_id: str,
@@ -186,7 +142,16 @@ class RunManager:
) )
async with self._lock: async with self._lock:
self._runs[run_id] = record self._runs[run_id] = record
await self._persist_to_store(record) persisted = False
try:
await self._persist_new_run_to_store(record)
persisted = True
except Exception:
logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
raise
finally:
if not persisted:
self._runs.pop(run_id, None)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
@@ -364,16 +329,8 @@ class RunManager:
raise ConflictError(f"Thread {thread_id} already has an active run") raise ConflictError(f"Thread {thread_id} already has an active run")
if multitask_strategy in ("interrupt", "rollback") and inflight: if multitask_strategy in ("interrupt", "rollback") and inflight:
for r in inflight:
r.abort_action = multitask_strategy
r.abort_event.set()
if r.task is not None and not r.task.done():
r.task.cancel()
r.status = RunStatus.interrupted
r.updated_at = now
interrupted_run_ids.append(r.run_id)
logger.info( logger.info(
"Cancelled %d inflight run(s) on thread %s (strategy=%s)", "Preparing to cancel %d inflight run(s) on thread %s (strategy=%s)",
len(inflight), len(inflight),
thread_id, thread_id,
multitask_strategy, multitask_strategy,
@@ -393,10 +350,29 @@ class RunManager:
model_name=model_name, model_name=model_name,
) )
self._runs[run_id] = record self._runs[run_id] = record
persisted = False
try:
await self._persist_new_run_to_store(record)
persisted = True
except Exception:
logger.warning("Failed to persist run %s; rolled back in-memory record", run_id, exc_info=True)
raise
finally:
if not persisted:
self._runs.pop(run_id, None)
if multitask_strategy in ("interrupt", "rollback") and inflight:
for r in inflight:
r.abort_action = multitask_strategy
r.abort_event.set()
if r.task is not None and not r.task.done():
r.task.cancel()
r.status = RunStatus.interrupted
r.updated_at = now
interrupted_run_ids.append(r.run_id)
for interrupted_run_id in interrupted_run_ids: for interrupted_run_id in interrupted_run_ids:
await self._persist_status(interrupted_run_id, RunStatus.interrupted) await self._persist_status(interrupted_run_id, RunStatus.interrupted)
await self._persist_to_store(record)
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id) logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
return record return record
@@ -95,30 +95,12 @@ class RunStore(abc.ABC):
) -> None: ) -> None:
pass pass
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Persist a best-effort running snapshot without changing run status."""
return None
@abc.abstractmethod @abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread. """Aggregate token usage for completed runs in a thread.
Returns a dict with keys: total_tokens, total_input_tokens, Returns a dict with keys: total_tokens, total_input_tokens,
@@ -82,22 +82,14 @@ class MemoryRunStore(RunStore):
self._runs[run_id][key] = value self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def update_run_progress(self, run_id, **kwargs):
if run_id in self._runs and self._runs[run_id].get("status") == "running":
for key, value in kwargs.items():
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def list_pending(self, *, before=None): async def list_pending(self, *, before=None):
now = before or datetime.now(UTC).isoformat() now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now] results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"]) results.sort(key=lambda r: r["created_at"])
return results return results
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]: async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
statuses = ("success", "error", "running") if include_active else ("success", "error") completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")]
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
by_model: dict[str, dict] = {} by_model: dict[str, dict] = {}
for r in completed: for r in completed:
model = r.get("model_name") or "unknown" model = r.get("model_name") or "unknown"
@@ -153,6 +153,8 @@ async def run_agent(
journal = None journal = None
journal = None
# Track whether "events" was requested but skipped # Track whether "events" was requested but skipped
if "events" in requested_modes: if "events" in requested_modes:
logger.info( logger.info(
@@ -175,7 +177,6 @@ async def run_agent(
thread_id=thread_id, thread_id=thread_id,
event_store=event_store, event_store=event_store,
track_token_usage=getattr(run_events_config, "track_token_usage", True), track_token_usage=getattr(run_events_config, "track_token_usage", True),
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
) )
# 1. Mark running # 1. Mark running
@@ -218,70 +218,6 @@ class TestBuildPatchedMessagesPatching:
assert mw._build_patched_messages(msgs) is None assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
mw = DanglingToolCallMiddleware()
msgs = [
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:11"),
_tc("web_search", "web_search:12"),
_tc("web_search", "web_search:13"),
]
),
_tool_msg("web_search:11", "web_search"),
_tool_msg("web_search:12", "web_search"),
_tool_msg("web_search:13", "web_search"),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:9"),
_tc("web_search", "web_search:10"),
_tc("web_search", "web_search:11"),
]
),
_tool_msg("web_search:9", "web_search"),
_tool_msg("web_search:10", "web_search"),
_tool_msg("web_search:11", "web_search"),
]
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_tool_msg("web_search:11", "web_search"),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "web_search:11"
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
result = _tool_msg("web_search:11", "web_search")
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
result,
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert patched[1] is result
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self): def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware() mw = DanglingToolCallMiddleware()
msgs = [ msgs = [
-104
View File
@@ -714,110 +714,6 @@ class TestExternalUsageRecords:
assert j._subagent_tokens == 0 assert j._subagent_tokens == 0
class TestProgressSnapshots:
@pytest.mark.anyio
async def test_on_llm_end_reports_progress_snapshot(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0,
)
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
await j.flush()
assert snapshots
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["message_count"] == 1
assert snapshots[-1]["last_ai_message"] == "Answer"
@pytest.mark.anyio
async def test_throttled_progress_flush_emits_trailing_snapshot(self):
snapshots: list[dict] = []
trailing_seen = asyncio.Event()
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
if snapshot["total_tokens"] == 45:
trailing_seen.set()
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0.01,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
await j.flush()
assert len(snapshots) >= 2
assert snapshots[-1]["total_tokens"] == 45
assert snapshots[-1]["llm_call_count"] == 2
assert snapshots[-1]["last_ai_message"] == "Second"
@pytest.mark.anyio
async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=10.0,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.sleep(0)
assert snapshots[-1]["total_tokens"] == 15
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(j.flush(), timeout=0.2)
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["last_ai_message"] == "First"
class TestChatModelStartHumanMessage: class TestChatModelStartHumanMessage:
"""Tests for on_chat_model_start extracting the first human message.""" """Tests for on_chat_model_start extracting the first human message."""
+122
View File
@@ -1,5 +1,6 @@
"""Tests for RunManager.""" """Tests for RunManager."""
import asyncio
import re import re
import pytest import pytest
@@ -231,6 +232,81 @@ async def test_create_record_is_not_store_only(manager: RunManager):
assert record.store_only is False assert record.store_only is False
@pytest.mark.anyio
async def test_create_rolls_back_in_memory_record_on_store_failure():
"""create() must fail and hide the run when the initial store write fails."""
from unittest.mock import AsyncMock
store = MemoryRunStore()
store.put = AsyncMock(side_effect=RuntimeError("db down"))
manager = RunManager(store=store)
with pytest.raises(RuntimeError, match="db down"):
await manager.create("thread-1")
assert manager._runs == {}
assert await manager.list_by_thread("thread-1") == []
@pytest.mark.anyio
async def test_create_rolls_back_in_memory_record_on_store_cancellation():
"""create() must also roll back when cancelled during the initial store write."""
store = MemoryRunStore()
async def cancelled_put(run_id, **kwargs):
raise asyncio.CancelledError
store.put = cancelled_put
manager = RunManager(store=store)
with pytest.raises(asyncio.CancelledError):
await manager.create("thread-1")
assert manager._runs == {}
assert await manager.list_by_thread("thread-1") == []
@pytest.mark.anyio
async def test_create_does_not_expose_run_until_store_persist_completes():
"""Concurrent readers must wait until the new run has been persisted."""
store = MemoryRunStore()
manager = RunManager(store=store)
original_put = store.put
put_started = asyncio.Event()
allow_put = asyncio.Event()
async def blocking_put(run_id, **kwargs):
put_started.set()
await allow_put.wait()
return await original_put(run_id, **kwargs)
store.put = blocking_put
create_task = asyncio.create_task(manager.create("thread-1"))
list_task = None
try:
await put_started.wait()
list_task = asyncio.create_task(manager.list_by_thread("thread-1"))
await asyncio.sleep(0)
assert not list_task.done()
allow_put.set()
record = await create_task
runs = await list_task
assert [run.run_id for run in runs] == [record.run_id]
finally:
allow_put.set()
cleanup_tasks = []
for task in (list_task, create_task):
if task is None:
continue
if not task.done():
task.cancel()
cleanup_tasks.append(task)
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
@pytest.mark.anyio @pytest.mark.anyio
async def test_get_prefers_in_memory_record_over_store(): async def test_get_prefers_in_memory_record_over_store():
"""In-memory records retain task/control state when store has same run.""" """In-memory records retain task/control state when store has same run."""
@@ -318,6 +394,52 @@ async def test_create_or_reject_interrupt_persists_interrupted_status_to_store()
assert stored_old["status"] == "interrupted" assert stored_old["status"] == "interrupted"
@pytest.mark.anyio
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_fails():
"""A failed new-run persist must not cancel the existing inflight run."""
from unittest.mock import AsyncMock
store = MemoryRunStore()
manager = RunManager(store=store)
old = await manager.create("thread-1")
await manager.set_status(old.run_id, RunStatus.running)
store.put = AsyncMock(side_effect=RuntimeError("db down"))
with pytest.raises(RuntimeError, match="db down"):
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
stored_old = await store.get(old.run_id)
assert list(manager._runs) == [old.run_id]
assert old.status == RunStatus.running
assert old.abort_event.is_set() is False
assert stored_old is not None
assert stored_old["status"] == "running"
@pytest.mark.anyio
async def test_create_or_reject_does_not_interrupt_old_run_when_new_run_store_write_is_cancelled():
"""Cancellation during new-run persist must not cancel the existing run."""
store = MemoryRunStore()
manager = RunManager(store=store)
old = await manager.create("thread-1")
await manager.set_status(old.run_id, RunStatus.running)
async def cancelled_put(run_id, **kwargs):
raise asyncio.CancelledError
store.put = cancelled_put
with pytest.raises(asyncio.CancelledError):
await manager.create_or_reject("thread-1", multitask_strategy="interrupt")
stored_old = await store.get(old.run_id)
assert list(manager._runs) == [old.run_id]
assert old.status == RunStatus.running
assert old.abort_event.is_set() is False
assert stored_old is not None
assert stored_old["status"] == "running"
@pytest.mark.anyio @pytest.mark.anyio
async def test_create_or_reject_rollback_persists_interrupted_status_to_store(): async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
"""rollback strategy should persist interrupted status for old runs.""" """rollback strategy should persist interrupted status for old runs."""
-122
View File
@@ -10,7 +10,6 @@ from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository from deerflow.persistence.run import RunRepository
from deerflow.runtime import RunManager, RunStatus from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.store.base import RunStore
async def _make_repo(tmp_path): async def _make_repo(tmp_path):
@@ -27,42 +26,6 @@ async def _cleanup():
await close_engine() await close_engine()
class _CustomRunStoreWithoutProgress(RunStore):
async def put(self, *args, **kwargs):
return None
async def get(self, *args, **kwargs):
return None
async def list_by_thread(self, *args, **kwargs):
return []
async def update_status(self, *args, **kwargs):
return None
async def delete(self, *args, **kwargs):
return None
async def update_model_name(self, *args, **kwargs):
return None
async def update_run_completion(self, *args, **kwargs):
return None
async def list_pending(self, *args, **kwargs):
return []
async def aggregate_tokens_by_thread(self, *args, **kwargs):
return {}
@pytest.mark.anyio
async def test_update_run_progress_defaults_to_noop_for_custom_store():
store = _CustomRunStoreWithoutProgress()
await store.update_run_progress("r1", total_tokens=1)
class TestRunRepository: class TestRunRepository:
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_and_get(self, tmp_path): async def test_put_and_get(self, tmp_path):
@@ -207,69 +170,6 @@ class TestRunRepository:
assert row["total_tokens"] == 100 assert row["total_tokens"] == 100
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_keeps_status_running(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
message_count=2,
last_ai_message="partial answer",
)
row = await repo.get("r1")
assert row["status"] == "running"
assert row["total_tokens"] == 50
assert row["llm_call_count"] == 1
assert row["message_count"] == 2
assert row["last_ai_message"] == "partial answer"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
lead_agent_tokens=30,
subagent_tokens=20,
message_count=2,
)
await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")
row = await repo.get("r1")
assert row["total_input_tokens"] == 40
assert row["total_output_tokens"] == 10
assert row["total_tokens"] == 60
assert row["llm_call_count"] == 1
assert row["lead_agent_tokens"] == 30
assert row["subagent_tokens"] == 20
assert row["message_count"] == 2
assert row["last_ai_message"] == "updated"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)
await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)
row = await repo.get("r1")
assert row["status"] == "success"
assert row["total_tokens"] == 100
assert row["llm_call_count"] == 1
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path): async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
@@ -325,28 +225,6 @@ class TestRunRepository:
} }
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("success-run", thread_id="t1", status="running")
await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
await repo.put("running-run", thread_id="t1", status="running")
await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)
without_active = await repo.aggregate_tokens_by_thread("t1")
with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)
assert without_active["total_tokens"] == 100
assert without_active["total_runs"] == 1
assert with_active["total_tokens"] == 125
assert with_active["total_runs"] == 2
assert with_active["by_caller"] == {
"lead_agent": 120,
"subagent": 5,
"middleware": 0,
}
await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_list_by_thread_ordered_desc(self, tmp_path): async def test_list_by_thread_ordered_desc(self, tmp_path):
"""list_by_thread returns newest first.""" """list_by_thread returns newest first."""
-27
View File
@@ -53,30 +53,3 @@ def test_thread_token_usage_returns_stable_shape():
}, },
} }
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1") run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
def test_thread_token_usage_can_include_active_runs():
run_store = MagicMock()
run_store.aggregate_tokens_by_thread = AsyncMock(
return_value={
"total_tokens": 175,
"total_input_tokens": 120,
"total_output_tokens": 55,
"total_runs": 3,
"by_model": {"unknown": {"tokens": 175, "runs": 3}},
"by_caller": {
"lead_agent": 145,
"subagent": 25,
"middleware": 5,
},
},
)
app = _make_app(run_store)
with TestClient(app) as client:
response = client.get("/api/threads/thread-1/token-usage?include_active=true")
assert response.status_code == 200
assert response.json()["total_tokens"] == 175
assert response.json()["total_runs"] == 3
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)