Compare commits

...

4 Commits

Author SHA1 Message Date
Willem Jiang cfd9c61b9a Merge branch 'main' into fix-3127 2026-05-22 21:56:04 +08:00
Nan Gao f0bae28636 fix(middleware): handle repeated tool call ids (#3143)
* fix(middleware): handle repeated tool call ids

* add tests

* refactor(middleware): rely on tool result queues
2026-05-22 21:44:05 +08:00
Lawrance_YXLiao 2eeb597985 fix(runs): expose active progress counters (#3148)
* fix(runs): expose active progress counters

* fix(runs): avoid delayed progress flush on completion

* fix(runs): tighten progress snapshot semantics

* fix(runs): preserve omitted progress fields

* chore(runs): remove duplicate journal initialization
2026-05-22 21:42:14 +08:00
Willem Jiang 4731605d99 fix(sandbox): add group/other read permissions to uploaded files for Docker sandbox (#3127)
When using AIO sandbox with LocalContainerBackend, uploaded files are
  created with 0o600 (owner-only) permissions by the gateway process
  running as root. The sandbox process inside the Docker container runs
  as a non-root user and cannot read these bind-mounted files, causing
  a "Permission denied" error on read_file.

  Add `needs_upload_permission_adjustment` attribute to SandboxProvider
  (default True) to indicate that uploaded files need chmod adjustment.
  LocalSandboxProvider opts out (same user). A new `_make_file_sandbox_readable`
  function adds S_IRGRP | S_IROTH bits after files are written, changing
  permissions from 0o600 to 0o644 so the sandbox can read the uploads.

  fixes #3127
2026-05-21 17:51:56 +08:00
16 changed files with 624 additions and 17 deletions
+25 -2
View File
@@ -66,6 +66,14 @@ class RunResponse(BaseModel):
multitask_strategy: str = "reject" multitask_strategy: str = "reject"
created_at: str = "" created_at: str = ""
updated_at: str = "" updated_at: str = ""
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
class ThreadTokenUsageModelBreakdown(BaseModel): class ThreadTokenUsageModelBreakdown(BaseModel):
@@ -111,6 +119,14 @@ def _record_to_response(record: RunRecord) -> RunResponse:
multitask_strategy=record.multitask_strategy, multitask_strategy=record.multitask_strategy,
created_at=record.created_at, created_at=record.created_at,
updated_at=record.updated_at, updated_at=record.updated_at,
total_input_tokens=record.total_input_tokens,
total_output_tokens=record.total_output_tokens,
total_tokens=record.total_tokens,
llm_call_count=record.llm_call_count,
lead_agent_tokens=record.lead_agent_tokens,
subagent_tokens=record.subagent_tokens,
middleware_tokens=record.middleware_tokens,
message_count=record.message_count,
) )
@@ -402,8 +418,15 @@ async def list_run_events(
@router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse) @router.get("/{thread_id}/token-usage", response_model=ThreadTokenUsageResponse)
@require_permission("threads", "read", owner_check=True) @require_permission("threads", "read", owner_check=True)
async def thread_token_usage(thread_id: str, request: Request) -> ThreadTokenUsageResponse: async def thread_token_usage(
thread_id: str,
request: Request,
include_active: bool = Query(default=False, description="Include running run progress snapshots"),
) -> ThreadTokenUsageResponse:
"""Thread-level token usage aggregation.""" """Thread-level token usage aggregation."""
run_store = get_run_store(request) run_store = get_run_store(request)
agg = await run_store.aggregate_tokens_by_thread(thread_id) if include_active:
agg = await run_store.aggregate_tokens_by_thread(thread_id, include_active=True)
else:
agg = await run_store.aggregate_tokens_by_thread(thread_id)
return ThreadTokenUsageResponse(thread_id=thread_id, **agg) return ThreadTokenUsageResponse(thread_id=thread_id, **agg)
+28
View File
@@ -74,6 +74,25 @@ def _make_file_sandbox_writable(file_path: os.PathLike[str] | str) -> None:
os.chmod(file_path, writable_mode, **chmod_kwargs) os.chmod(file_path, writable_mode, **chmod_kwargs)
def _make_file_sandbox_readable(file_path: os.PathLike[str] | str) -> None:
"""Ensure uploaded files are readable by the sandbox process.
For Docker sandboxes (AIO), the gateway writes files as root with 0o600
permissions, then bind-mounts the host directory into the container. The
sandbox process inside the container runs as a non-root user and cannot
read those files without group/other read bits. This function adds
``S_IRGRP | S_IROTH`` so the sandbox can read the uploaded content.
"""
file_stat = os.lstat(file_path)
if stat.S_ISLNK(file_stat.st_mode):
logger.warning("Skipping sandbox chmod for symlinked upload path: %s", file_path)
return
readable_mode = stat.S_IMODE(file_stat.st_mode) | stat.S_IRGRP | stat.S_IROTH
chmod_kwargs = {"follow_symlinks": False} if os.chmod in os.supports_follow_symlinks else {}
os.chmod(file_path, readable_mode, **chmod_kwargs)
def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool: def _uses_thread_data_mounts(sandbox_provider: SandboxProvider) -> bool:
return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False)) return bool(getattr(sandbox_provider, "uses_thread_data_mounts", False))
@@ -276,6 +295,15 @@ async def upload_files(
_cleanup_uploaded_paths(written_paths) _cleanup_uploaded_paths(written_paths)
raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}") raise HTTPException(status_code=500, detail=f"Failed to upload {file.filename}: {str(e)}")
# When the sandbox uses bind-mounted thread data directories (e.g. AIO with
# LocalContainerBackend), uploaded files are visible inside the container but
# retain the 0o600 permissions set by the gateway. The sandbox process runs
# as a different user and cannot read them. Adjust permissions to add
# group/other read bits so the sandbox can access the files.
if not sync_to_sandbox and getattr(sandbox_provider, "needs_upload_permission_adjustment", True):
for file_path in written_paths:
_make_file_sandbox_readable(file_path)
if sync_to_sandbox: if sync_to_sandbox:
for file_path, virtual_path in sandbox_sync_targets: for file_path, virtual_path in sandbox_sync_targets:
_make_file_sandbox_writable(file_path) _make_file_sandbox_writable(file_path)
@@ -15,6 +15,7 @@ to the end of the message list as before_model + add_messages reducer would do.
import json import json
import logging import logging
from collections import defaultdict, deque
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable
from typing import override from typing import override
@@ -109,10 +110,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
This normalizes model-bound causal order before provider serialization while This normalizes model-bound causal order before provider serialization while
preserving already-valid transcripts unchanged. preserving already-valid transcripts unchanged.
""" """
tool_messages_by_id: dict[str, ToolMessage] = {} tool_messages_by_id: dict[str, deque[ToolMessage]] = defaultdict(deque)
for msg in messages: for msg in messages:
if isinstance(msg, ToolMessage): if isinstance(msg, ToolMessage):
tool_messages_by_id.setdefault(msg.tool_call_id, msg) tool_messages_by_id[msg.tool_call_id].append(msg)
tool_call_ids: set[str] = set() tool_call_ids: set[str] = set()
for msg in messages: for msg in messages:
@@ -124,7 +125,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
tool_call_ids.add(tc_id) tool_call_ids.add(tc_id)
patched: list = [] patched: list = []
consumed_tool_msg_ids: set[str] = set()
patch_count = 0 patch_count = 0
for msg in messages: for msg in messages:
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids: if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
@@ -136,13 +136,13 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
for tc in self._message_tool_calls(msg): for tc in self._message_tool_calls(msg):
tc_id = tc.get("id") tc_id = tc.get("id")
if not tc_id or tc_id in consumed_tool_msg_ids: if not tc_id:
continue continue
existing_tool_msg = tool_messages_by_id.get(tc_id) tool_msg_queue = tool_messages_by_id.get(tc_id)
existing_tool_msg = tool_msg_queue.popleft() if tool_msg_queue else None
if existing_tool_msg is not None: if existing_tool_msg is not None:
patched.append(existing_tool_msg) patched.append(existing_tool_msg)
consumed_tool_msg_ids.add(tc_id)
else: else:
patched.append( patched.append(
ToolMessage( ToolMessage(
@@ -152,7 +152,6 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error", status="error",
) )
) )
consumed_tool_msg_ids.add(tc_id)
patch_count += 1 patch_count += 1
if patched == messages: if patched == messages:
@@ -227,9 +227,48 @@ class RunRepository(RunStore):
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values)) await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit() await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Update token usage + convenience fields while a run is still active."""
values: dict[str, Any] = {"updated_at": datetime.now(UTC)}
optional_counters = {
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
}
for key, value in optional_counters.items():
if value is not None:
values[key] = value
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id, RunRow.status == "running").values(**values))
await session.commit()
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage via a single SQL GROUP BY query.""" """Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error")) statuses = ("success", "error", "running") if include_active else ("success", "error")
_completed = RunRow.status.in_(statuses)
_thread = RunRow.thread_id == thread_id _thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown") model_name = func.coalesce(RunRow.model_name, "unknown")
@@ -20,7 +20,7 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
import time import time
from collections.abc import Mapping from collections.abc import Awaitable, Callable, Mapping
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from uuid import UUID from uuid import UUID
@@ -46,6 +46,8 @@ class RunJournal(BaseCallbackHandler):
*, *,
track_token_usage: bool = True, track_token_usage: bool = True,
flush_threshold: int = 20, flush_threshold: int = 20,
progress_reporter: Callable[[dict], Awaitable[None]] | None = None,
progress_flush_interval: float = 5.0,
): ):
super().__init__() super().__init__()
self.run_id = run_id self.run_id = run_id
@@ -53,10 +55,16 @@ class RunJournal(BaseCallbackHandler):
self._store = event_store self._store = event_store
self._track_tokens = track_token_usage self._track_tokens = track_token_usage
self._flush_threshold = flush_threshold self._flush_threshold = flush_threshold
self._progress_reporter = progress_reporter
self._progress_flush_interval = progress_flush_interval
# Write buffer # Write buffer
self._buffer: list[dict] = [] self._buffer: list[dict] = []
self._pending_flush_tasks: set[asyncio.Task[None]] = set() self._pending_flush_tasks: set[asyncio.Task[None]] = set()
self._pending_progress_task: asyncio.Task[None] | None = None
self._pending_progress_delayed = False
self._progress_dirty = False
self._last_progress_flush = 0.0
# Token accumulators # Token accumulators
self._total_input_tokens = 0 self._total_input_tokens = 0
@@ -294,6 +302,8 @@ class RunJournal(BaseCallbackHandler):
else: else:
self._lead_agent_tokens += total_tk self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
if messages: if messages:
self._counted_message_llm_run_ids.add(str(run_id)) self._counted_message_llm_run_ids.add(str(run_id))
@@ -445,6 +455,8 @@ class RunJournal(BaseCallbackHandler):
else: else:
self._lead_agent_tokens += total_tk self._lead_agent_tokens += total_tk
self._schedule_progress_flush()
def set_first_human_message(self, content: str) -> None: def set_first_human_message(self, content: str) -> None:
"""Record the first human message for convenience fields.""" """Record the first human message for convenience fields."""
self._first_human_msg = content[:2000] if content else None self._first_human_msg = content[:2000] if content else None
@@ -474,6 +486,14 @@ class RunJournal(BaseCallbackHandler):
"""Force flush remaining buffer. Called in worker's finally block.""" """Force flush remaining buffer. Called in worker's finally block."""
if self._pending_flush_tasks: if self._pending_flush_tasks:
await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True) await asyncio.gather(*tuple(self._pending_flush_tasks), return_exceptions=True)
while self._pending_progress_task is not None and not self._pending_progress_task.done():
if self._pending_progress_delayed:
self._pending_progress_task.cancel()
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
self._progress_dirty = False
self._pending_progress_delayed = False
break
await asyncio.gather(self._pending_progress_task, return_exceptions=True)
while self._buffer: while self._buffer:
batch = self._buffer[: self._flush_threshold] batch = self._buffer[: self._flush_threshold]
@@ -484,6 +504,57 @@ class RunJournal(BaseCallbackHandler):
self._buffer = batch + self._buffer self._buffer = batch + self._buffer
raise raise
def _schedule_progress_flush(self) -> None:
    """Request a throttled, best-effort progress snapshot for the active run.

    Does nothing when no progress reporter is configured. If the last
    successful flush is more recent than the flush interval, the state is
    marked dirty and a delayed flush is requested for the remaining time.
    Otherwise a flush task is started immediately, unless one is already in
    flight (then the state is only marked dirty) or no event loop is running.
    """
    if self._progress_reporter is None:
        return

    since_last_flush = time.monotonic() - self._last_progress_flush
    if since_last_flush < self._progress_flush_interval:
        # Still inside the throttle window: defer to a trailing flush.
        self._progress_dirty = True
        self._schedule_delayed_progress_flush(self._progress_flush_interval - since_last_flush)
        return

    in_flight = self._pending_progress_task
    if in_flight is not None and not in_flight.done():
        # The running task checks the dirty flag after its write.
        self._progress_dirty = True
        return

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        # No running event loop in this context; skip silently.
        return

    self._progress_dirty = False
    # The snapshot is captured now, at scheduling time, not when the task runs.
    self._pending_progress_task = loop.create_task(self._flush_progress_async(snapshot=self.get_completion_data()))
def _schedule_delayed_progress_flush(self, delay: float) -> None:
    """Start a progress flush task that first waits ``delay`` seconds.

    Does nothing when a progress task is still in flight or when no event
    loop is running. Negative delays are clamped to zero.

    Args:
        delay: Seconds to wait before the snapshot is written.
    """
    in_flight = self._pending_progress_task
    if in_flight is not None and not in_flight.done():
        return

    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        return

    wait_seconds = max(0.0, delay)
    # Recorded before the task starts so the delayed state is visible immediately.
    self._pending_progress_delayed = wait_seconds > 0
    self._pending_progress_task = loop.create_task(self._flush_progress_async(delay=wait_seconds))
async def _flush_progress_async(self, *, snapshot: dict | None = None, delay: float = 0.0) -> None:
if self._progress_reporter is None:
return
if delay > 0:
self._pending_progress_delayed = True
await asyncio.sleep(delay)
self._pending_progress_delayed = False
dirty_before_write = self._progress_dirty
self._progress_dirty = False
snapshot_to_write = snapshot or self.get_completion_data()
try:
await self._progress_reporter(snapshot_to_write)
self._last_progress_flush = time.monotonic()
except Exception:
logger.warning("Failed to persist progress snapshot for run %s", self.run_id, exc_info=True)
if dirty_before_write or self._progress_dirty:
self._progress_dirty = False
self._pending_progress_task = None
self._schedule_delayed_progress_flush(self._progress_flush_interval)
def get_completion_data(self) -> dict: def get_completion_data(self) -> dict:
"""Return accumulated token and message data for run completion.""" """Return accumulated token and message data for run completion."""
return { return {
@@ -38,6 +38,16 @@ class RunRecord:
error: str | None = None error: str | None = None
model_name: str | None = None model_name: str | None = None
store_only: bool = False store_only: bool = False
total_input_tokens: int = 0
total_output_tokens: int = 0
total_tokens: int = 0
llm_call_count: int = 0
lead_agent_tokens: int = 0
subagent_tokens: int = 0
middleware_tokens: int = 0
message_count: int = 0
last_ai_message: str | None = None
first_human_message: str | None = None
class RunManager: class RunManager:
@@ -102,16 +112,53 @@ class RunManager:
error=row.get("error"), error=row.get("error"),
model_name=row.get("model_name"), model_name=row.get("model_name"),
store_only=True, store_only=True,
total_input_tokens=row.get("total_input_tokens") or 0,
total_output_tokens=row.get("total_output_tokens") or 0,
total_tokens=row.get("total_tokens") or 0,
llm_call_count=row.get("llm_call_count") or 0,
lead_agent_tokens=row.get("lead_agent_tokens") or 0,
subagent_tokens=row.get("subagent_tokens") or 0,
middleware_tokens=row.get("middleware_tokens") or 0,
message_count=row.get("message_count") or 0,
last_ai_message=row.get("last_ai_message"),
first_human_message=row.get("first_human_message"),
) )
async def update_run_completion(self, run_id: str, **kwargs) -> None: async def update_run_completion(self, run_id: str, **kwargs) -> None:
"""Persist token usage and completion data to the backing store.""" """Persist token usage and completion data to the backing store."""
async with self._lock:
record = self._runs.get(run_id)
if record is not None:
for key, value in kwargs.items():
if key == "status":
continue
if hasattr(record, key) and value is not None:
setattr(record, key, value)
record.updated_at = _now_iso()
if self._store is not None: if self._store is not None:
try: try:
await self._store.update_run_completion(run_id, **kwargs) await self._store.update_run_completion(run_id, **kwargs)
except Exception: except Exception:
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True) logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
async def update_run_progress(self, run_id: str, **kwargs) -> None:
    """Persist a running token/message snapshot without changing status.

    Args:
        run_id: Identifier of the run the snapshot belongs to.
        **kwargs: Snapshot fields. ``None`` values leave the in-memory record
            untouched. A ``status`` key is ignored, mirroring
            ``update_run_completion``, so a progress update can never move a
            run out of the running state.

    When the run is tracked in memory and is no longer running, the snapshot
    is dropped entirely. Store failures are logged and not propagated.
    """
    # Drop "status" up front: the record has a ``status`` attribute, so it
    # would otherwise pass the hasattr() check below and also be forwarded
    # to the store.
    fields = {key: value for key, value in kwargs.items() if key != "status"}

    should_persist = True
    async with self._lock:
        record = self._runs.get(run_id)
        if record is not None:
            should_persist = record.status == RunStatus.running
            if should_persist:
                for key, value in fields.items():
                    if value is not None and hasattr(record, key):
                        setattr(record, key, value)
                record.updated_at = _now_iso()

    # Runs unknown to this manager are still forwarded to the store.
    if should_persist and self._store is not None:
        try:
            await self._store.update_run_progress(run_id, **fields)
        except Exception:
            logger.warning("Failed to persist run progress for %s", run_id, exc_info=True)
async def create( async def create(
self, self,
thread_id: str, thread_id: str,
@@ -95,12 +95,30 @@ class RunStore(abc.ABC):
) -> None: ) -> None:
pass pass
async def update_run_progress(
self,
run_id: str,
*,
total_input_tokens: int | None = None,
total_output_tokens: int | None = None,
total_tokens: int | None = None,
llm_call_count: int | None = None,
lead_agent_tokens: int | None = None,
subagent_tokens: int | None = None,
middleware_tokens: int | None = None,
message_count: int | None = None,
last_ai_message: str | None = None,
first_human_message: str | None = None,
) -> None:
"""Persist a best-effort running snapshot without changing run status."""
return None
@abc.abstractmethod @abc.abstractmethod
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass pass
@abc.abstractmethod @abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread. """Aggregate token usage for completed runs in a thread.
Returns a dict with keys: total_tokens, total_input_tokens, Returns a dict with keys: total_tokens, total_input_tokens,
@@ -82,14 +82,22 @@ class MemoryRunStore(RunStore):
self._runs[run_id][key] = value self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat() self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def update_run_progress(self, run_id, **kwargs):
if run_id in self._runs and self._runs[run_id].get("status") == "running":
for key, value in kwargs.items():
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
async def list_pending(self, *, before=None): async def list_pending(self, *, before=None):
now = before or datetime.now(UTC).isoformat() now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now] results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"]) results.sort(key=lambda r: r["created_at"])
return results return results
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in ("success", "error")] statuses = ("success", "error", "running") if include_active else ("success", "error")
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
by_model: dict[str, dict] = {} by_model: dict[str, dict] = {}
for r in completed: for r in completed:
model = r.get("model_name") or "unknown" model = r.get("model_name") or "unknown"
@@ -153,8 +153,6 @@ async def run_agent(
journal = None journal = None
journal = None
# Track whether "events" was requested but skipped # Track whether "events" was requested but skipped
if "events" in requested_modes: if "events" in requested_modes:
logger.info( logger.info(
@@ -177,6 +175,7 @@ async def run_agent(
thread_id=thread_id, thread_id=thread_id,
event_store=event_store, event_store=event_store,
track_token_usage=getattr(run_events_config, "track_token_usage", True), track_token_usage=getattr(run_events_config, "track_token_usage", True),
progress_reporter=lambda snapshot: run_manager.update_run_progress(run_id, **snapshot),
) )
# 1. Mark running # 1. Mark running
@@ -63,6 +63,7 @@ class LocalSandboxProvider(SandboxProvider):
""" """
uses_thread_data_mounts = True uses_thread_data_mounts = True
needs_upload_permission_adjustment = False
def __init__(self, max_cached_threads: int = DEFAULT_MAX_CACHED_THREAD_SANDBOXES): def __init__(self, max_cached_threads: int = DEFAULT_MAX_CACHED_THREAD_SANDBOXES):
"""Initialize the local sandbox provider with static path mappings. """Initialize the local sandbox provider with static path mappings.
@@ -10,6 +10,7 @@ class SandboxProvider(ABC):
"""Abstract base class for sandbox providers""" """Abstract base class for sandbox providers"""
uses_thread_data_mounts: bool = False uses_thread_data_mounts: bool = False
needs_upload_permission_adjustment: bool = True
@abstractmethod @abstractmethod
def acquire(self, thread_id: str | None = None) -> str: def acquire(self, thread_id: str | None = None) -> str:
@@ -218,6 +218,70 @@ class TestBuildPatchedMessagesPatching:
assert mw._build_patched_messages(msgs) is None assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_ids_across_ai_turns_keep_their_own_tool_results(self):
mw = DanglingToolCallMiddleware()
msgs = [
HumanMessage(content="summary", name="summary", additional_kwargs={"hide_from_ui": True}),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:11"),
_tc("web_search", "web_search:12"),
_tc("web_search", "web_search:13"),
]
),
_tool_msg("web_search:11", "web_search"),
_tool_msg("web_search:12", "web_search"),
_tool_msg("web_search:13", "web_search"),
_ai_with_tool_calls(
[
_tc("web_search", "web_search:9"),
_tc("web_search", "web_search:10"),
_tc("web_search", "web_search:11"),
]
),
_tool_msg("web_search:9", "web_search"),
_tool_msg("web_search:10", "web_search"),
_tool_msg("web_search:11", "web_search"),
]
assert mw._build_patched_messages(msgs) is None
def test_reused_tool_call_id_patches_second_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_tool_msg("web_search:11", "web_search"),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "web_search:11"
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_reused_tool_call_id_consumes_later_result_for_first_dangling_occurrence(self):
mw = DanglingToolCallMiddleware()
result = _tool_msg("web_search:11", "web_search")
msgs = [
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
_ai_with_tool_calls([_tc("web_search", "web_search:11")]),
result,
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert patched[1] is result
assert patched[1].status == "success"
assert isinstance(patched[3], ToolMessage)
assert patched[3].tool_call_id == "web_search:11"
assert patched[3].status == "error"
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self): def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware() mw = DanglingToolCallMiddleware()
msgs = [ msgs = [
+104
View File
@@ -714,6 +714,110 @@ class TestExternalUsageRecords:
assert j._subagent_tokens == 0 assert j._subagent_tokens == 0
class TestProgressSnapshots:
@pytest.mark.anyio
async def test_on_llm_end_reports_progress_snapshot(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0,
)
usage = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
j.on_llm_end(_make_llm_response("Answer", usage=usage), run_id=uuid4(), parent_run_id=None, tags=["lead_agent"])
await j.flush()
assert snapshots
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["message_count"] == 1
assert snapshots[-1]["last_ai_message"] == "Answer"
@pytest.mark.anyio
async def test_throttled_progress_flush_emits_trailing_snapshot(self):
snapshots: list[dict] = []
trailing_seen = asyncio.Event()
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
if snapshot["total_tokens"] == 45:
trailing_seen.set()
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=0.01,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(trailing_seen.wait(), timeout=1.0)
await j.flush()
assert len(snapshots) >= 2
assert snapshots[-1]["total_tokens"] == 45
assert snapshots[-1]["llm_call_count"] == 2
assert snapshots[-1]["last_ai_message"] == "Second"
@pytest.mark.anyio
async def test_flush_cancels_delayed_progress_without_final_progress_write(self):
snapshots: list[dict] = []
async def reporter(snapshot: dict) -> None:
snapshots.append(snapshot)
store = MemoryRunEventStore()
j = RunJournal(
"r1",
"t1",
store,
flush_threshold=100,
progress_reporter=reporter,
progress_flush_interval=10.0,
)
j.on_llm_end(
_make_llm_response("First", usage={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.sleep(0)
assert snapshots[-1]["total_tokens"] == 15
j.on_llm_end(
_make_llm_response("Second", usage={"input_tokens": 20, "output_tokens": 10, "total_tokens": 30}),
run_id=uuid4(),
parent_run_id=None,
tags=["lead_agent"],
)
await asyncio.wait_for(j.flush(), timeout=0.2)
assert snapshots[-1]["total_tokens"] == 15
assert snapshots[-1]["llm_call_count"] == 1
assert snapshots[-1]["last_ai_message"] == "First"
class TestChatModelStartHumanMessage: class TestChatModelStartHumanMessage:
"""Tests for on_chat_model_start extracting the first human message.""" """Tests for on_chat_model_start extracting the first human message."""
+122
View File
@@ -10,6 +10,7 @@ from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository from deerflow.persistence.run import RunRepository
from deerflow.runtime import RunManager, RunStatus from deerflow.runtime import RunManager, RunStatus
from deerflow.runtime.runs.store.base import RunStore
async def _make_repo(tmp_path): async def _make_repo(tmp_path):
@@ -26,6 +27,42 @@ async def _cleanup():
await close_engine() await close_engine()
class _CustomRunStoreWithoutProgress(RunStore):
async def put(self, *args, **kwargs):
return None
async def get(self, *args, **kwargs):
return None
async def list_by_thread(self, *args, **kwargs):
return []
async def update_status(self, *args, **kwargs):
return None
async def delete(self, *args, **kwargs):
return None
async def update_model_name(self, *args, **kwargs):
return None
async def update_run_completion(self, *args, **kwargs):
return None
async def list_pending(self, *args, **kwargs):
return []
async def aggregate_tokens_by_thread(self, *args, **kwargs):
return {}
@pytest.mark.anyio
async def test_update_run_progress_defaults_to_noop_for_custom_store():
store = _CustomRunStoreWithoutProgress()
await store.update_run_progress("r1", total_tokens=1)
class TestRunRepository: class TestRunRepository:
@pytest.mark.anyio @pytest.mark.anyio
async def test_put_and_get(self, tmp_path): async def test_put_and_get(self, tmp_path):
@@ -170,6 +207,69 @@ class TestRunRepository:
assert row["total_tokens"] == 100 assert row["total_tokens"] == 100
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_keeps_status_running(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_progress(
"r1",
total_input_tokens=40,
total_output_tokens=10,
total_tokens=50,
llm_call_count=1,
message_count=2,
last_ai_message="partial answer",
)
row = await repo.get("r1")
assert row["status"] == "running"
assert row["total_tokens"] == 50
assert row["llm_call_count"] == 1
assert row["message_count"] == 2
assert row["last_ai_message"] == "partial answer"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_preserves_omitted_fields(self, tmp_path):
    """A partial progress update leaves fields it does not mention untouched.

    The first call writes a full set of counters; the second passes only
    ``total_tokens`` and ``last_ai_message``. Only those two values may
    change in the stored row.
    """
    repo = await _make_repo(tmp_path)
    try:
        await repo.put("r1", thread_id="t1", status="running")

        await repo.update_run_progress(
            "r1",
            total_input_tokens=40,
            total_output_tokens=10,
            total_tokens=50,
            llm_call_count=1,
            lead_agent_tokens=30,
            subagent_tokens=20,
            message_count=2,
        )
        await repo.update_run_progress("r1", total_tokens=60, last_ai_message="updated")

        row = await repo.get("r1")
        # Fields omitted from the second call keep their first-call values.
        assert row["total_input_tokens"] == 40
        assert row["total_output_tokens"] == 10
        assert row["llm_call_count"] == 1
        assert row["lead_agent_tokens"] == 30
        assert row["subagent_tokens"] == 20
        assert row["message_count"] == 2
        # Fields named in the second call reflect the new values.
        assert row["total_tokens"] == 60
        assert row["last_ai_message"] == "updated"
    finally:
        # Run cleanup even when an assertion fails, so a failing test does
        # not leave the engine open for the tests that follow.
        await _cleanup()
@pytest.mark.anyio
async def test_update_run_progress_skips_terminal_runs(self, tmp_path):
    """A progress update arriving after completion must not change the row."""
    repo = await _make_repo(tmp_path)
    try:
        await repo.put("r1", thread_id="t1", status="running")
        await repo.update_run_completion("r1", status="success", total_tokens=100, llm_call_count=1)

        # Late progress write with different numbers; expected to be ignored.
        await repo.update_run_progress("r1", total_tokens=200, llm_call_count=2)

        row = await repo.get("r1")
        assert row["status"] == "success"
        assert row["total_tokens"] == 100
        assert row["llm_call_count"] == 1
    finally:
        # Run cleanup even when an assertion fails, so a failing test does
        # not leave the engine open for the tests that follow.
        await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path): async def test_aggregate_tokens_by_thread_counts_completed_runs_only(self, tmp_path):
repo = await _make_repo(tmp_path) repo = await _make_repo(tmp_path)
@@ -225,6 +325,28 @@ class TestRunRepository:
} }
await _cleanup() await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_can_include_active_runs(self, tmp_path):
    """``include_active=True`` adds still-running runs to the thread totals.

    One run completes with 100 tokens and a second stays running with 25
    tokens of recorded progress. The default aggregation sees only the
    completed run; with ``include_active=True`` both are counted.
    """
    repo = await _make_repo(tmp_path)
    try:
        await repo.put("success-run", thread_id="t1", status="running")
        await repo.update_run_completion("success-run", status="success", total_tokens=100, lead_agent_tokens=100)
        await repo.put("running-run", thread_id="t1", status="running")
        await repo.update_run_progress("running-run", total_tokens=25, lead_agent_tokens=20, subagent_tokens=5)

        without_active = await repo.aggregate_tokens_by_thread("t1")
        with_active = await repo.aggregate_tokens_by_thread("t1", include_active=True)

        assert without_active["total_tokens"] == 100
        assert without_active["total_runs"] == 1
        assert with_active["total_tokens"] == 125
        assert with_active["total_runs"] == 2
        # lead_agent: 100 (completed) + 20 (running); subagent: 5 (running).
        assert with_active["by_caller"] == {
            "lead_agent": 120,
            "subagent": 5,
            "middleware": 0,
        }
    finally:
        # Run cleanup even when an assertion fails, so a failing test does
        # not leave the engine open for the tests that follow.
        await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_list_by_thread_ordered_desc(self, tmp_path): async def test_list_by_thread_ordered_desc(self, tmp_path):
"""list_by_thread returns newest first.""" """list_by_thread returns newest first."""
+27
View File
@@ -53,3 +53,30 @@ def test_thread_token_usage_returns_stable_shape():
}, },
} }
run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1") run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1")
def test_thread_token_usage_can_include_active_runs():
    """``?include_active=true`` reaches the run store as ``include_active=True``."""
    aggregated_usage = {
        "total_tokens": 175,
        "total_input_tokens": 120,
        "total_output_tokens": 55,
        "total_runs": 3,
        "by_model": {"unknown": {"tokens": 175, "runs": 3}},
        "by_caller": {
            "lead_agent": 145,
            "subagent": 25,
            "middleware": 5,
        },
    }
    run_store = MagicMock()
    run_store.aggregate_tokens_by_thread = AsyncMock(return_value=aggregated_usage)

    with TestClient(_make_app(run_store)) as client:
        response = client.get("/api/threads/thread-1/token-usage?include_active=true")

    assert response.status_code == 200
    payload = response.json()
    assert payload["total_tokens"] == 175
    assert payload["total_runs"] == 3
    # The query flag must be forwarded to the store as a keyword argument.
    run_store.aggregate_tokens_by_thread.assert_awaited_once_with("thread-1", include_active=True)
+56
View File
@@ -219,6 +219,7 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
provider = MagicMock() provider = MagicMock()
provider.uses_thread_data_mounts = True provider.uses_thread_data_mounts = True
provider.needs_upload_permission_adjustment = False
provider.acquire.return_value = "local" provider.acquire.return_value = "local"
sandbox = MagicMock() sandbox = MagicMock()
provider.get.return_value = sandbox provider.get.return_value = sandbox
@@ -228,12 +229,14 @@ def test_upload_files_does_not_adjust_permissions_for_local_sandbox(tmp_path):
patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir), patch.object(uploads, "ensure_uploads_dir", return_value=thread_uploads_dir),
patch.object(uploads, "get_sandbox_provider", return_value=provider), patch.object(uploads, "get_sandbox_provider", return_value=provider),
patch.object(uploads, "_make_file_sandbox_writable") as make_writable, patch.object(uploads, "_make_file_sandbox_writable") as make_writable,
patch.object(uploads, "_make_file_sandbox_readable") as make_readable,
): ):
file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads")) file = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace())) result = asyncio.run(call_unwrapped(uploads.upload_files, "thread-local", request=MagicMock(), files=[file], config=SimpleNamespace()))
assert result.success is True assert result.success is True
make_writable.assert_not_called() make_writable.assert_not_called()
make_readable.assert_not_called()
def test_upload_files_acquires_non_local_sandbox_before_writing(tmp_path): def test_upload_files_acquires_non_local_sandbox_before_writing(tmp_path):
@@ -431,6 +434,59 @@ def test_make_file_sandbox_writable_skips_symlinks(tmp_path):
chmod.assert_not_called() chmod.assert_not_called()
def test_make_file_sandbox_readable_adds_read_bits_for_regular_files(tmp_path):
    """A regular file starting at 0o600 ends up readable by owner, group and others."""
    target = tmp_path / "data.csv"
    target.write_bytes(b"csv-data")
    # Start from the owner-only 0o600 mode that open_upload_file_no_symlink sets.
    target.chmod(0o600)

    uploads._make_file_sandbox_readable(target)

    mode_after = stat.S_IMODE(target.stat().st_mode)
    for read_bit in (stat.S_IRUSR, stat.S_IRGRP, stat.S_IROTH):
        assert mode_after & read_bit
def test_make_file_sandbox_readable_skips_symlinks(tmp_path):
    """No ``chmod`` is issued when ``lstat`` reports the path as a symlink."""
    target = tmp_path / "target-link.txt"
    target.write_text("hello", encoding="utf-8")

    # Make lstat report a symlink without creating a real one.
    fake_lstat_result = MagicMock(st_mode=stat.S_IFLNK)

    with (
        patch.object(uploads.os, "lstat", return_value=fake_lstat_result),
        patch.object(uploads.os, "chmod") as chmod_mock,
    ):
        uploads._make_file_sandbox_readable(target)

    chmod_mock.assert_not_called()
def test_upload_files_adjusts_read_permissions_for_mounted_non_local_sandbox(tmp_path):
    """Uploads get the read-permission fix when the provider asks for it.

    Models the AIO sandbox with LocalContainerBackend: thread data is
    mounted (``uses_thread_data_mounts=True``) and
    ``needs_upload_permission_adjustment`` keeps its default of ``True``.
    """
    uploads_dir = tmp_path / "uploads"
    uploads_dir.mkdir(parents=True)

    provider = MagicMock(
        uses_thread_data_mounts=True,
        needs_upload_permission_adjustment=True,
    )

    with (
        patch.object(uploads, "get_uploads_dir", return_value=uploads_dir),
        patch.object(uploads, "ensure_uploads_dir", return_value=uploads_dir),
        patch.object(uploads, "get_sandbox_provider", return_value=provider),
        patch.object(uploads, "_make_file_sandbox_readable") as make_readable,
    ):
        upload = UploadFile(filename="notes.txt", file=BytesIO(b"hello uploads"))
        pending_upload = call_unwrapped(
            uploads.upload_files,
            "thread-aio",
            request=MagicMock(),
            files=[upload],
            config=SimpleNamespace(),
        )
        result = asyncio.run(pending_upload)

    assert result.success is True
    make_readable.assert_called_once()
    # The first positional argument is the path of the file that was written.
    adjusted_path = make_readable.call_args.args[0]
    assert adjusted_path.name == "notes.txt"
def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path): def test_upload_files_rejects_dotdot_and_dot_filenames(tmp_path):
thread_uploads_dir = tmp_path / "uploads" thread_uploads_dir = tmp_path / "uploads"
thread_uploads_dir.mkdir(parents=True) thread_uploads_dir.mkdir(parents=True)