fix(runtime): make rollback restore checkpoint supersede newer checkpoints (#2582)
* Restore rollback checkpoints with fresh ids * Tighten rollback checkpoint tests and imports * Update test_run_worker_rollback.py --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -23,6 +23,8 @@ from dataclasses import dataclass, field
|
||||
from functools import lru_cache
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
@@ -442,6 +444,12 @@ async def _rollback_to_pre_run_checkpoint(
|
||||
if checkpoint_to_restore.get("id") is None:
|
||||
logger.warning("Run %s rollback skipped: pre-run checkpoint has no checkpoint id", run_id)
|
||||
return
|
||||
restore_marker = _new_checkpoint_marker()
|
||||
checkpoint_to_restore = {
|
||||
**checkpoint_to_restore,
|
||||
"id": restore_marker["id"],
|
||||
"ts": restore_marker["ts"],
|
||||
}
|
||||
metadata = pre_run_snapshot.get("metadata", {})
|
||||
metadata_to_restore = metadata if isinstance(metadata, dict) else {}
|
||||
raw_checkpoint_ns = pre_run_snapshot.get("checkpoint_ns")
|
||||
@@ -493,6 +501,11 @@ async def _rollback_to_pre_run_checkpoint(
|
||||
)
|
||||
|
||||
|
||||
def _new_checkpoint_marker() -> dict[str, str]:
|
||||
marker = empty_checkpoint()
|
||||
return {"id": marker["id"], "ts": marker["ts"]}
|
||||
|
||||
|
||||
def _lg_mode_to_sse_event(mode: str) -> str:
|
||||
"""Map LangGraph internal stream_mode name to SSE event name.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user