fix(runtime): make rollback restore checkpoint supersede newer checkpoints (#2582)
* Restore rollback checkpoints with fresh ids * Tighten rollback checkpoint tests and imports * Update test_run_worker_rollback.py --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -3,6 +3,8 @@ from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, call
|
||||
|
||||
import pytest
|
||||
from langgraph.checkpoint.base import empty_checkpoint
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
from deerflow.runtime.runs.manager import RunManager
|
||||
from deerflow.runtime.runs.schemas import RunStatus
|
||||
@@ -16,6 +18,14 @@ class FakeCheckpointer:
|
||||
self.aput_writes = AsyncMock()
|
||||
|
||||
|
||||
def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
|
||||
checkpoint = empty_checkpoint()
|
||||
checkpoint["id"] = checkpoint_id
|
||||
checkpoint["channel_values"] = {"messages": messages}
|
||||
checkpoint["channel_versions"] = {"messages": version}
|
||||
return checkpoint
|
||||
|
||||
|
||||
def test_build_runtime_context_includes_app_config_when_present():
|
||||
app_config = object()
|
||||
|
||||
@@ -110,16 +120,16 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
)
|
||||
|
||||
checkpointer.adelete_thread.assert_not_awaited()
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{
|
||||
"id": "ckpt-1",
|
||||
"channel_versions": {"messages": 3},
|
||||
"channel_values": {"messages": ["before"]},
|
||||
},
|
||||
{"source": "input"},
|
||||
{"messages": 3},
|
||||
)
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
|
||||
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
assert restored_checkpoint["id"] != "ckpt-1"
|
||||
assert "channel_versions" in restored_checkpoint
|
||||
assert "channel_values" in restored_checkpoint
|
||||
assert restored_checkpoint["channel_versions"] == {"messages": 3}
|
||||
assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
|
||||
assert restored_metadata == {"source": "input"}
|
||||
assert new_versions == {"messages": 3}
|
||||
assert checkpointer.aput_writes.await_args_list == [
|
||||
call(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
|
||||
@@ -134,6 +144,40 @@ async def test_rollback_restores_snapshot_without_deleting_thread():
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
|
||||
checkpointer = InMemorySaver()
|
||||
thread_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
before_checkpoint = _make_checkpoint("0001", ["before"], 1)
|
||||
before_config = checkpointer.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1})
|
||||
after_checkpoint = _make_checkpoint("0002", ["after"], 2)
|
||||
after_config = checkpointer.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2})
|
||||
checkpointer.put_writes(after_config, [("messages", "pending-after")], task_id="task-after")
|
||||
|
||||
await _rollback_to_pre_run_checkpoint(
|
||||
checkpointer=checkpointer,
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
pre_run_checkpoint_id="0001",
|
||||
pre_run_snapshot={
|
||||
"checkpoint_ns": "",
|
||||
"checkpoint": before_checkpoint,
|
||||
"metadata": {"step": 1},
|
||||
"pending_writes": [("task-before", "messages", "pending-before")],
|
||||
},
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
latest = checkpointer.get_tuple(thread_config)
|
||||
|
||||
assert latest is not None
|
||||
assert latest.config["configurable"]["checkpoint_id"] != "0001"
|
||||
assert latest.config["configurable"]["checkpoint_id"] != "0002"
|
||||
assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
|
||||
assert latest.pending_writes == [("task-before", "messages", "pending-before")]
|
||||
assert ("task-after", "messages", "pending-after") not in latest.pending_writes
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_rollback_deletes_thread_when_no_snapshot_exists():
|
||||
checkpointer = FakeCheckpointer(put_result=None)
|
||||
@@ -194,12 +238,13 @@ async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
|
||||
snapshot_capture_failed=False,
|
||||
)
|
||||
|
||||
checkpointer.aput.assert_awaited_once_with(
|
||||
{"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}},
|
||||
{"id": "ckpt-1", "channel_versions": {}},
|
||||
{},
|
||||
{},
|
||||
)
|
||||
checkpointer.aput.assert_awaited_once()
|
||||
restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
|
||||
assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
|
||||
assert restored_checkpoint["id"] != "ckpt-1"
|
||||
assert restored_checkpoint["channel_versions"] == {}
|
||||
assert restored_metadata == {}
|
||||
assert new_versions == {}
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
Reference in New Issue
Block a user