mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 16:35:59 +00:00
fix: harden run finalization persistence (#3155)
* fix: harden run finalization persistence * style: format gateway recovery test * fix: align run repository return types * fix: harden completion recovery follow-up
This commit is contained in:
@@ -59,7 +59,12 @@ class RunStore(abc.ABC):
|
||||
status: str,
|
||||
*,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
) -> bool | None:
|
||||
"""Update a run status.
|
||||
|
||||
Returns ``False`` when the store can prove no row was updated. Older or
|
||||
lightweight stores may return ``None`` when they cannot report rowcount.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -92,7 +97,11 @@ class RunStore(abc.ABC):
|
||||
last_ai_message: str | None = None,
|
||||
first_human_message: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
) -> bool | None:
|
||||
"""Persist final completion fields.
|
||||
|
||||
Returns ``False`` when the store can prove no row was updated.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def update_run_progress(
|
||||
@@ -117,6 +126,11 @@ class RunStore(abc.ABC):
|
||||
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]:
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def list_inflight(self, *, before: str | None = None) -> list[dict[str, Any]]:
|
||||
"""Return persisted runs that are still ``pending`` or ``running``."""
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
||||
"""Aggregate token usage for completed runs in a thread.
|
||||
|
||||
@@ -65,6 +65,8 @@ class MemoryRunStore(RunStore):
|
||||
if error is not None:
|
||||
self._runs[run_id]["error"] = error
|
||||
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def update_model_name(self, run_id, model_name):
|
||||
if run_id in self._runs:
|
||||
@@ -81,6 +83,8 @@ class MemoryRunStore(RunStore):
|
||||
if value is not None:
|
||||
self._runs[run_id][key] = value
|
||||
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
||||
return True
|
||||
return False
|
||||
|
||||
async def update_run_progress(self, run_id, **kwargs):
|
||||
if run_id in self._runs and self._runs[run_id].get("status") == "running":
|
||||
@@ -95,6 +99,12 @@ class MemoryRunStore(RunStore):
|
||||
results.sort(key=lambda r: r["created_at"])
|
||||
return results
|
||||
|
||||
async def list_inflight(self, *, before=None):
|
||||
now = before or datetime.now(UTC).isoformat()
|
||||
results = [r for r in self._runs.values() if r["status"] in ("pending", "running") and r["created_at"] <= now]
|
||||
results.sort(key=lambda r: r["created_at"])
|
||||
return results
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str, *, include_active: bool = False) -> dict[str, Any]:
|
||||
statuses = ("success", "error", "running") if include_active else ("success", "error")
|
||||
completed = [r for r in self._runs.values() if r["thread_id"] == thread_id and r.get("status") in statuses]
|
||||
|
||||
Reference in New Issue
Block a user