mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
fix: harden run finalization persistence (#3155)
* fix: harden run finalization persistence * style: format gateway recovery test * fix: align run repository return types * fix: harden completion recovery follow-up
This commit is contained in:
@@ -94,25 +94,35 @@ class RunRepository(RunStore):
|
||||
created_at=None,
|
||||
follow_up_to_run_id=None,
|
||||
):
|
||||
"""Insert or update a run row.
|
||||
|
||||
``RunManager`` retries ``put`` after transient SQLite failures. Making
|
||||
this operation idempotent prevents a successful-but-unacknowledged first
|
||||
commit from turning the retry into a primary-key failure.
|
||||
"""
|
||||
resolved_user_id = resolve_user_id(user_id, method_name="RunRepository.put")
|
||||
now = datetime.now(UTC)
|
||||
row = RunRow(
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
user_id=resolved_user_id,
|
||||
model_name=self._normalize_model_name(model_name),
|
||||
status=status,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata_json=self._safe_json(metadata) or {},
|
||||
kwargs_json=self._safe_json(kwargs) or {},
|
||||
error=error,
|
||||
follow_up_to_run_id=follow_up_to_run_id,
|
||||
created_at=datetime.fromisoformat(created_at) if created_at else now,
|
||||
updated_at=now,
|
||||
)
|
||||
created = datetime.fromisoformat(created_at) if created_at else now
|
||||
values = {
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"user_id": resolved_user_id,
|
||||
"model_name": self._normalize_model_name(model_name),
|
||||
"status": status,
|
||||
"multitask_strategy": multitask_strategy,
|
||||
"metadata_json": self._safe_json(metadata) or {},
|
||||
"kwargs_json": self._safe_json(kwargs) or {},
|
||||
"error": error,
|
||||
"follow_up_to_run_id": follow_up_to_run_id,
|
||||
"updated_at": now,
|
||||
}
|
||||
async with self._sf() as session:
|
||||
session.add(row)
|
||||
row = await session.get(RunRow, run_id)
|
||||
if row is None:
|
||||
session.add(RunRow(run_id=run_id, created_at=created, **values))
|
||||
else:
|
||||
for key, value in values.items():
|
||||
setattr(row, key, value)
|
||||
await session.commit()
|
||||
|
||||
async def get(
|
||||
@@ -146,13 +156,14 @@ class RunRepository(RunStore):
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def update_status(self, run_id, status, *, error=None):
|
||||
async def update_status(self, run_id, status, *, error=None) -> bool:
|
||||
values: dict[str, Any] = {"status": status, "updated_at": datetime.now(UTC)}
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
async with self._sf() as session:
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
result = await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
await session.commit()
|
||||
return result.rowcount != 0
|
||||
|
||||
async def update_model_name(self, run_id, model_name):
|
||||
async with self._sf() as session:
|
||||
@@ -187,6 +198,26 @@ class RunRepository(RunStore):
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_inflight(self, *, before=None):
|
||||
"""Return persisted active runs for startup recovery."""
|
||||
if before is None:
|
||||
before_dt = datetime.now(UTC)
|
||||
elif isinstance(before, datetime):
|
||||
before_dt = before
|
||||
else:
|
||||
before_dt = datetime.fromisoformat(before)
|
||||
stmt = (
|
||||
select(RunRow)
|
||||
.where(
|
||||
RunRow.status.in_(("pending", "running")),
|
||||
RunRow.created_at <= before_dt,
|
||||
)
|
||||
.order_by(RunRow.created_at.asc())
|
||||
)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
@@ -203,8 +234,11 @@ class RunRepository(RunStore):
|
||||
last_ai_message: str | None = None,
|
||||
first_human_message: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
"""Update status + token usage + convenience fields on run completion."""
|
||||
) -> bool:
|
||||
"""Update status + token usage + convenience fields on run completion.
|
||||
|
||||
Returns ``False`` when no run row matched the requested ``run_id``.
|
||||
"""
|
||||
values: dict[str, Any] = {
|
||||
"status": status,
|
||||
"total_input_tokens": total_input_tokens,
|
||||
@@ -224,8 +258,9 @@ class RunRepository(RunStore):
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
async with self._sf() as session:
|
||||
await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
result = await session.execute(update(RunRow).where(RunRow.run_id == run_id).values(**values))
|
||||
await session.commit()
|
||||
return result.rowcount != 0
|
||||
|
||||
async def update_run_progress(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user