fix: harden run finalization persistence (#3155)

* fix: harden run finalization persistence

* style: format gateway recovery test

* fix: align run repository return types

* fix: harden completion recovery follow-up
This commit is contained in:
AochenShen99
2026-05-23 00:09:06 +08:00
committed by GitHub
parent f0bae28636
commit 66d6a6a4e8
8 changed files with 755 additions and 56 deletions
+47 -2
View File
@@ -52,6 +52,9 @@ class _CustomRunStoreWithoutProgress(RunStore):
async def list_pending(self, *args, **kwargs):
    """Return an empty list, ignoring every positional and keyword argument."""
    no_rows = []
    return no_rows
async def list_inflight(self, *args, **kwargs):
    """Return an empty list, ignoring every positional and keyword argument."""
    no_rows = []
    return no_rows
async def aggregate_tokens_by_thread(self, *args, **kwargs):
    """Return an empty mapping, ignoring every positional and keyword argument."""
    no_totals = {}
    return no_totals
@@ -75,6 +78,19 @@ class TestRunRepository:
assert row["status"] == "pending"
await _cleanup()
@pytest.mark.anyio
async def test_put_is_idempotent_for_retried_writes(self, tmp_path):
    """A second ``put`` for the same run id must overwrite the stored fields."""
    store = await _make_repo(tmp_path)

    # First write, then a retried write for the same run id with new values.
    await store.put("r1", thread_id="t1", assistant_id="old-agent", status="pending")
    await store.put("r1", thread_id="t1", assistant_id="new-agent", status="running", error="retry")

    stored = await store.get("r1")
    # The row is expected to reflect the values from the latest write.
    for column, expected in (
        ("assistant_id", "new-agent"),
        ("status", "running"),
        ("error", "retry"),
    ):
        assert stored[column] == expected
    await _cleanup()
@pytest.mark.anyio
async def test_get_missing_returns_none(self, tmp_path):
repo = await _make_repo(tmp_path)
@@ -85,11 +101,19 @@ class TestRunRepository:
async def test_update_status(self, tmp_path):
    """``update_status`` on an existing row reports success and persists the status.

    The status is updated exactly once so that the asserted return value
    belongs to the transition under test, not to a redundant second update
    of a row that already holds the target status.
    """
    repo = await _make_repo(tmp_path)
    await repo.put("r1", thread_id="t1")
    # Single update call: its return value is what the test checks.
    updated = await repo.update_status("r1", "running")
    row = await repo.get("r1")
    assert updated is True
    assert row["status"] == "running"
    await _cleanup()
@pytest.mark.anyio
async def test_update_status_returns_false_for_missing_row(self, tmp_path):
    """``update_status`` must return ``False`` when the run id was never stored."""
    store = await _make_repo(tmp_path)
    # No row was ever put for "missing", so the update has nothing to modify.
    outcome = await store.update_status("missing", "error", error="lost")
    assert outcome is False
    await _cleanup()
@pytest.mark.anyio
async def test_update_status_with_error(self, tmp_path):
repo = await _make_repo(tmp_path)
@@ -146,11 +170,24 @@ class TestRunRepository:
assert all(r["status"] == "pending" for r in pending)
await _cleanup()
@pytest.mark.anyio
async def test_list_inflight_returns_pending_and_running_before_cutoff(self, tmp_path):
    """``list_inflight`` is expected to yield only the old pending/running runs.

    Four runs are stored one second apart; the cutoff is the timestamp of the
    third one (a ``success`` run), and the fourth (``pending``) is newer than
    the cutoff.
    """
    store = await _make_repo(tmp_path)
    fixtures = (
        ("pending-old", "pending", "2026-01-01T00:00:00+00:00"),
        ("running-old", "running", "2026-01-01T00:00:01+00:00"),
        ("success-old", "success", "2026-01-01T00:00:02+00:00"),
        ("pending-new", "pending", "2026-01-01T00:00:03+00:00"),
    )
    # Insert sequentially, in the same order as the timestamps.
    for run_id, status, created_at in fixtures:
        await store.put(run_id, thread_id="t1", status=status, created_at=created_at)

    rows = await store.list_inflight(before="2026-01-01T00:00:02+00:00")
    found_ids = [entry["run_id"] for entry in rows]
    assert found_ids == ["pending-old", "running-old"]
    await _cleanup()
@pytest.mark.anyio
async def test_update_run_completion(self, tmp_path):
repo = await _make_repo(tmp_path)
await repo.put("r1", thread_id="t1", status="running")
await repo.update_run_completion(
updated = await repo.update_run_completion(
"r1",
status="success",
total_input_tokens=100,
@@ -165,6 +202,7 @@ class TestRunRepository:
first_human_message="What is the meaning?",
)
row = await repo.get("r1")
assert updated is True
assert row["status"] == "success"
assert row["total_tokens"] == 150
assert row["llm_call_count"] == 2
@@ -174,6 +212,13 @@ class TestRunRepository:
assert row["first_human_message"] == "What is the meaning?"
await _cleanup()
@pytest.mark.anyio
async def test_update_run_completion_returns_false_for_missing_row(self, tmp_path):
    """``update_run_completion`` must return ``False`` when the run id was never stored."""
    store = await _make_repo(tmp_path)
    # No row was ever put for "missing", so there is nothing to finalize.
    outcome = await store.update_run_completion("missing", status="error", total_tokens=1)
    assert outcome is False
    await _cleanup()
@pytest.mark.anyio
async def test_metadata_preserved(self, tmp_path):
repo = await _make_repo(tmp_path)