fix: harden run finalization persistence (#3155)

* fix: harden run finalization persistence

* style: format gateway recovery test

* fix: align run repository return types

* fix: harden completion recovery follow-up
This commit is contained in:
AochenShen99
2026-05-23 00:09:06 +08:00
committed by GitHub
parent f0bae28636
commit 66d6a6a4e8
8 changed files with 755 additions and 56 deletions
@@ -59,7 +59,12 @@ class RunStore(abc.ABC):
status: str,
*,
error: str | None = None,
) -> None:
) -> bool | None:
"""Update a run status.
Returns ``False`` when the store can prove no row was updated. Older or
lightweight stores may return ``None`` when they cannot report rowcount.
"""
pass
@abc.abstractmethod
@@ -92,7 +97,11 @@ class RunStore(abc.ABC):
last_ai_message: str | None = None,
first_human_message: str | None = None,
error: str | None = None,
) -> None:
) -> bool | None:
"""Persist final completion fields.
Returns ``False`` when the store can prove no row was updated.
"""
pass
async def update_run_progress(
@@ -117,6 +126,11 @@ class RunStore(abc.ABC):
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
pass
@abc.abstractmethod
async def list_inflight(self, *, before: str | None = None) -> list[dict[str, Any]]:
"""Return persisted runs that are still ``pending`` or ``running``."""
pass
@abc.abstractmethod
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
"""Aggregate token usage for completed runs in a thread.
@@ -65,6 +65,8 @@ class MemoryRunStore(RunStore):
if error is not None:
self._runs[run_id]["error"] = error
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
return True
return False
async def update_model_name(self, run_id, model_name):
if run_id in self._runs:
@@ -81,6 +83,8 @@ class MemoryRunStore(RunStore):
if value is not None:
self._runs[run_id][key] = value
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
return True
return False
async def update_run_progress(self, run_id, **kwargs):
if run_id in self._runs and self._runs[run_id].get("status") == "running":
@@ -95,6 +99,12 @@ class MemoryRunStore(RunStore):
results.sort(key=lambda r: r["created_at"])
return results
async def list_inflight(self, *, before=None):
now = before or datetime.now(UTC).isoformat()
results = [r for r in self._runs.values() if r["status"] in ("pending", "running") and r["created_at"] <= now]
results.sort(key=lambda r: r["created_at"])
return results
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
statuses = ("success", "error", "running") if include_active else ("success", "error")
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]