fix: harden run finalization persistence (#3155)

* fix: harden run finalization persistence

* style: format gateway recovery test

* fix: align run repository return types

* fix: harden completion recovery follow-up
This commit is contained in:
AochenShen99
2026-05-23 00:09:06 +08:00
committed by GitHub
parent f0bae28636
commit 66d6a6a4e8
8 changed files with 755 additions and 56 deletions
@@ -94,25 +94,35 @@ class RunRepository(RunStore):
created_at=None,
follow_up_to_run_id=None,
):
"""Insert or update a run row.
``RunManager`` retries ``put`` after transient SQLite failures. Making
this operation idempotent prevents a successful-but-unacknowledged first
commit from turning the retry into a primary-key failure.
"""
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
now = datetime.now(UTC)
row = RunRow(
run_id=run_id,
thread_id=thread_id,
assistant_id=assistant_id,
user_id=resolved_user_id,
model_name=self._normalize_model_name(model_name),
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
kwargs_json=self._safe_json(kwargs) or {},
error=error,
follow_up_to_run_id=follow_up_to_run_id,
created_at=datetime.fromisoformat(created_at) if created_at else now,
updated_at=now,
)
created = datetime.fromisoformat(created_at) if created_at else now
values = {
"thread_id": thread_id,
"assistant_id": assistant_id,
"user_id": resolved_user_id,
"model_name": self._normalize_model_name(model_name),
"status": status,
"multitask_strategy": multitask_strategy,
"metadata_json": self._safe_json(metadata) or {},
"kwargs_json": self._safe_json(kwargs) or {},
"error": error,
"follow_up_to_run_id": follow_up_to_run_id,
"updated_at": now,
}
async with self._sf() as session:
session.add(row)
row = await session.get(RunRow, run_id)
if row is None:
session.add(RunRow(run_id=run_id, created_at=created, **values))
else:
for key, value in values.items():
setattr(row, key, value)
await session.commit()
async def get(
@@ -146,13 +156,14 @@ class RunRepository(RunStore):
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_status(self, run_id, status, *, error=None) -> bool:
    """Set the status of a run row and refresh its ``updated_at`` timestamp.

    ``error`` is written only when it is not ``None``; otherwise the
    column is left out of the UPDATE and keeps its stored value.

    Returns ``True`` when the UPDATE reported a non-zero row count and
    ``False`` when no row matched ``run_id``.
    """
    values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
    if error is not None:
        values["error"] = error
    async with self._sf() as session:
        # Issue the UPDATE exactly once and keep the result so the caller
        # can tell whether the run row actually exists.
        result = await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
        await session.commit()
    return result.rowcount != 0
async def update_model_name(self, run_id, model_name):
async with self._sf() as session:
@@ -187,6 +198,26 @@ class RunRepository(RunStore):
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def list_inflight(self, *, before=None):
    """List persisted runs that are still active, for startup recovery.

    ``before`` selects the cutoff: ``None`` means the current UTC time, a
    ``datetime`` is used as-is, and any other value is parsed with
    ``datetime.fromisoformat``. Rows whose status is ``pending`` or
    ``running`` and whose ``created_at`` is at or before the cutoff are
    returned oldest first, each converted with ``_row_to_dict``.
    """
    if isinstance(before, datetime):
        cutoff = before
    elif before is None:
        cutoff = datetime.now(UTC)
    else:
        cutoff = datetime.fromisoformat(before)

    is_active = RunRow.status.in_(("pending", "running"))
    created_by_cutoff = RunRow.created_at <= cutoff
    query = select(RunRow).where(is_active, created_by_cutoff).order_by(RunRow.created_at.asc())

    async with self._sf() as session:
        rows = (await session.execute(query)).scalars()
        return [self._row_to_dict(row) for row in rows]
async def update_run_completion(
self,
run_id: str,
@@ -203,8 +234,11 @@ class RunRepository(RunStore):
last_ai_message: str | None = None,
first_human_message: str | None = None,
error: str | None = None,
) -> None:
"""Update status + token usage + convenience fields on run completion."""
) -> bool:
"""Update status + token usage + convenience fields on run completion.
Returns ``False`` when no run row matched the requested ``run_id``.
"""
values: dict[str, Any] = {
"status": status,
"total_input_tokens": total_input_tokens,
@@ -224,8 +258,9 @@ class RunRepository(RunStore):
if error is not None:
values["error"] = error
async with self._sf() as session:
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
result = await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
await session.commit()
return result.rowcount != 0
async def update_run_progress(
self,