deer-flow/backend/tests/test_run_worker_rollback.py
He Wang c810e9f809 fix(harness)!: hydrate runs from RunStore and persist interrupted status (#2932)
* fix(harness): hydrate run history from RunStore and persist cancellation status

fix:
- Make RunManager.get() async and hydrate from RunStore when in-memory record is missing
- Merge store rows into list_by_thread() with in-memory precedence for active runs
- Persist interrupted status to RunStore in cancel() and create_or_reject(interrupt|rollback)
- Extract _persist_status() to reuse the best-effort store update pattern
- Await run_mgr.get() in all gateway endpoints
- Return 409 with distinct message for store-only runs not active on current worker

Closes #2812, Closes #2813

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* fix(harness): consistent sort and guarded hydration in RunManager

fix:
- list_by_thread() now sorts by created_at desc (newest first) even when
  no RunStore is configured, matching the store-backed code path
- guard _record_from_store() call sites in get() and list_by_thread()
  with best-effort error handling so a single malformed store row cannot
  turn read paths into 500s

test:
- update test_list_by_thread assertion to expect newest-first order
- seed MemoryRunStore via public put() API instead of writing to _runs

* fix(harness): guard store-only runs from streaming and fix get() TOCTOU

Add RunRecord.store_only flag set by _record_from_store so callers can
distinguish hydrated history from live in-memory runs.  join_run and
stream_existing_run (action=None) now return 409 instead of hanging
forever on an empty MemoryStreamBridge channel.

Re-check _runs under lock after the store await in RunManager.get() so a
concurrent create() that lands between the two checks returns the
authoritative in-memory record rather than a stale store-hydrated copy.
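
The re-check described above can be illustrated with a minimal sketch. `TinyRunManager`, its dict-based store, and `_load_from_store` are hypothetical stand-ins invented for this example, not the real `RunManager` API; only the pattern (check, await the store without the lock, check again under the lock) follows the description.

```python
import asyncio


class TinyRunManager:
    """Hypothetical, simplified stand-in; not the real RunManager."""

    def __init__(self, store: dict):
        self._runs: dict = {}
        self._lock = asyncio.Lock()
        self._store = store  # assumption: a plain dict plays the RunStore role

    async def create(self, run_id: str) -> dict:
        async with self._lock:
            record = {"run_id": run_id, "store_only": False}
            self._runs[run_id] = record
            return record

    async def _load_from_store(self, run_id: str):
        await asyncio.sleep(0)  # yield control, like a real store round trip
        return self._store.get(run_id)

    async def get(self, run_id: str):
        async with self._lock:
            record = self._runs.get(run_id)
        if record is not None:
            return record
        row = await self._load_from_store(run_id)  # the lock is not held here
        async with self._lock:
            # Re-check: a create() may have landed during the store await.
            record = self._runs.get(run_id)
        if record is not None:
            return record
        if row is None:
            return None
        return {**row, "store_only": True}  # hydrated history, not a live run


async def demo() -> dict:
    manager = TinyRunManager(store={"run-1": {"run_id": "run-1"}})
    get_task = asyncio.create_task(manager.get("run-1"))
    await asyncio.sleep(0)  # let get() reach the store await
    live = await manager.create("run-1")
    fetched = await get_task
    return {"fetched": fetched, "live": live}


result = asyncio.run(demo())
```

In this sketch `result["fetched"]` is the same object that `create()` returned, because the second check sees the in-memory record before the stale store row is used.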

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

* fix(harness): reorder bridge fetch in join_run and make list_by_thread limit explicit

Move get_stream_bridge() after the store_only guard in join_run so a
missing bridge cannot produce 503 for historical runs before the 409
guard fires.

Add limit parameter to RunManager.list_by_thread (default 100, matching
the store's page size) and pass it explicitly to the store call.
Update docstring to document the limit instead of claiming all runs are
returned.

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

* fix(harness): cap list_by_thread result to limit after merge

Apply [:limit] to all return paths in list_by_thread so the method
consistently returns at most limit records regardless of how many
in-memory runs exist, making the limit parameter a true upper bound
on the response size rather than just a store-query hint.
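
As a rough illustration of the merge, ordering, and cap described in this and the earlier list_by_thread items, here is a sketch over plain dicts. `merge_runs` and the record shape (`run_id`, `created_at`, `status`) are assumptions made for the example; the real method works on RunRecord objects and a RunStore.

```python
from datetime import datetime, timedelta


def merge_runs(memory_records, store_rows, limit=100):
    """Merge in-memory and stored runs, newest first, capped at `limit`."""
    merged = {row["run_id"]: row for row in store_rows}
    # In-memory records take precedence over store rows with the same run_id.
    merged.update({rec["run_id"]: rec for rec in memory_records})
    ordered = sorted(merged.values(), key=lambda r: r["created_at"], reverse=True)
    # The slice makes `limit` an upper bound on the result size.
    return ordered[:limit]


base = datetime(2026, 1, 1)
memory = [{"run_id": "b", "created_at": base + timedelta(minutes=2), "status": "running"}]
store = [
    {"run_id": "a", "created_at": base + timedelta(minutes=1), "status": "success"},
    {"run_id": "b", "created_at": base + timedelta(minutes=2), "status": "pending"},
    {"run_id": "c", "created_at": base + timedelta(minutes=3), "status": "success"},
]
merged_runs = merge_runs(memory, store, limit=2)
```

With these inputs the result holds runs `c` and `b` in that order, and `b` carries the in-memory `running` status rather than the stored `pending` one.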

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

* fix `list_by_thread` docstring

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix(runtime): add update_model_name to RunStore to prevent SQL integrity errors

RunManager.update_model_name() was calling _persist_to_store() which uses
RunStore.put(), but RunRepository.put() is insert-only. This caused integrity
errors when updating model_name for existing runs in SQL-backed stores.

fix:
- Add abstract update_model_name method to RunStore base class
- Implement update_model_name in MemoryRunStore
- Implement update_model_name in RunRepository with proper normalization
- Add _persist_model_name helper in RunManager
- Update RunManager.update_model_name to use the new method

test:
- Add tests for update_model_name functionality
- Add integration tests for RunManager with SQL-backed store
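
The failure mode behind this fix can be reproduced with the standard-library `sqlite3` module. The `runs` table and the two helper functions below are hypothetical and only mirror the idea of an insert-only `put()` versus a dedicated update; they are not the RunRepository code.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Hypothetical schema: one row per run, keyed by run_id.
conn.execute("CREATE TABLE runs (run_id TEXT PRIMARY KEY, model_name TEXT)")
conn.execute("INSERT INTO runs VALUES (?, ?)", ("run-1", None))


def put(run_id, model_name):
    # Insert-only: fails if a row with this run_id already exists.
    conn.execute("INSERT INTO runs VALUES (?, ?)", (run_id, model_name))


def update_model_name(run_id, model_name):
    # Targeted update of an existing row.
    conn.execute("UPDATE runs SET model_name = ? WHERE run_id = ?", (model_name, run_id))


try:
    put("run-1", "gpt-4")
    insert_failed = False
except sqlite3.IntegrityError:
    insert_failed = True

update_model_name("run-1", "gpt-4")
stored = conn.execute(
    "SELECT model_name FROM runs WHERE run_id = ?", ("run-1",)
).fetchone()[0]
```

Here the second insert for `run-1` raises `sqlite3.IntegrityError`, while the update succeeds and leaves `gpt-4` in the row.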

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>

* fix(runtime): handle NULL status/on_disconnect in _record_from_store

`dict.get(key, default)` only uses the default when the key is absent,
so a SQL row with an explicit NULL status would pass `None` to
`RunStatus(None)` and raise, breaking hydration for otherwise valid rows.
Switch to `row.get(...) or fallback` so both missing and NULL values
get a safe default. Add tests for get() and list_by_thread() with a
NULL status row to prevent regression.
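
The difference between the two lookups can be seen in a small sketch. The `RunStatus` enum here is a hypothetical two-member stand-in for the real one, and the `pending` fallback follows the policy named later in this message.

```python
from enum import Enum


class RunStatus(str, Enum):  # hypothetical stand-in for the real enum
    pending = "pending"
    success = "success"


def status_with_get_default(row: dict) -> RunStatus:
    # The default only applies when the key is absent.
    return RunStatus(row.get("status", "pending"))


def status_with_or_fallback(row: dict) -> RunStatus:
    # `or` covers both a missing key and an explicit None (SQL NULL).
    return RunStatus(row.get("status") or "pending")


null_row = {"run_id": "run-1", "status": None}
missing_row = {"run_id": "run-2"}
```

`status_with_get_default(null_row)` raises `ValueError`, while `status_with_or_fallback` returns `RunStatus.pending` for both rows. One caveat of the `or` form is that any falsy value, such as an empty string, also falls back to the default.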

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

* fix(runs): address PR review feedback on store consistency changes

- Fix list_by_thread limit semantics: pass `store_limit = max(0, limit - len(memory_records))` to the store so newer store records are not crowded out by in-memory records
- Remove dead code: the cancelled guard after the raise is always True; simplify to `if wait and record.task`
- Document the _record_from_store NULL fallback policy (status→pending, on_disconnect→cancel) in its docstring

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.7 <noreply@anthropic.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-18 22:25:02 +08:00

387 lines
15 KiB
Python

import asyncio
from types import SimpleNamespace
from unittest.mock import AsyncMock, call

import pytest
from langgraph.checkpoint.base import empty_checkpoint
from langgraph.checkpoint.memory import InMemorySaver

from deerflow.runtime.runs.manager import RunManager
from deerflow.runtime.runs.schemas import RunStatus
from deerflow.runtime.runs.worker import (
    RunContext,
    _agent_factory_supports_app_config,
    _build_runtime_context,
    _install_runtime_context,
    _rollback_to_pre_run_checkpoint,
    run_agent,
)


class FakeCheckpointer:
    def __init__(self, *, put_result):
        self.adelete_thread = AsyncMock()
        self.aput = AsyncMock(return_value=put_result)
        self.aput_writes = AsyncMock()


def _make_checkpoint(checkpoint_id: str, messages: list[str], version: int):
    checkpoint = empty_checkpoint()
    checkpoint["id"] = checkpoint_id
    checkpoint["channel_values"] = {"messages": messages}
    checkpoint["channel_versions"] = {"messages": version}
    return checkpoint


def test_build_runtime_context_includes_app_config_when_present():
    app_config = object()
    context = _build_runtime_context("thread-1", "run-1", None, app_config)
    assert context["thread_id"] == "thread-1"
    assert context["run_id"] == "run-1"
    assert context["app_config"] is app_config


def test_install_runtime_context_preserves_existing_thread_id_and_threads_app_config():
    app_config = object()
    config = {"context": {"thread_id": "caller-thread"}}
    _install_runtime_context(
        config,
        {
            "thread_id": "record-thread",
            "run_id": "run-1",
            "app_config": app_config,
        },
    )
    assert config["context"]["thread_id"] == "caller-thread"
    assert config["context"]["run_id"] == "run-1"
    assert config["context"]["app_config"] is app_config


@pytest.mark.anyio
async def test_run_agent_threads_explicit_app_config_into_config_only_factory():
    run_manager = RunManager()
    record = await run_manager.create("thread-1")
    bridge = SimpleNamespace(
        publish=AsyncMock(),
        publish_end=AsyncMock(),
        cleanup=AsyncMock(),
    )
    app_config = object()
    captured: dict[str, object] = {}

    class DummyAgent:
        async def astream(self, graph_input, config=None, stream_mode=None, subgraphs=False):
            captured["astream_context"] = config["context"]
            yield {"messages": []}

    def factory(*, config):
        captured["factory_context"] = config["context"]
        return DummyAgent()

    await run_agent(
        bridge,
        run_manager,
        record,
        ctx=RunContext(checkpointer=None, app_config=app_config),
        agent_factory=factory,
        graph_input={},
        config={},
    )
    await asyncio.sleep(0)
    assert captured["factory_context"]["app_config"] is app_config
    assert captured["astream_context"]["app_config"] is app_config
    fetched = await run_manager.get(record.run_id)
    assert fetched is not None
    assert fetched.status == RunStatus.success
    bridge.publish_end.assert_awaited_once_with(record.run_id)
    bridge.cleanup.assert_awaited_once_with(record.run_id, delay=60)


@pytest.mark.anyio
async def test_rollback_restores_snapshot_without_deleting_thread():
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    await _rollback_to_pre_run_checkpoint(
        checkpointer=checkpointer,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="ckpt-1",
        pre_run_snapshot={
            "checkpoint_ns": "",
            "checkpoint": {
                "id": "ckpt-1",
                "channel_versions": {"messages": 3},
                "channel_values": {"messages": ["before"]},
            },
            "metadata": {"source": "input"},
            "pending_writes": [
                ("task-a", "messages", {"content": "first"}),
                ("task-a", "status", "done"),
                ("task-b", "events", {"type": "tool"}),
            ],
        },
        snapshot_capture_failed=False,
    )
    checkpointer.adelete_thread.assert_not_awaited()
    checkpointer.aput.assert_awaited_once()
    restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
    assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    assert restored_checkpoint["id"] != "ckpt-1"
    assert "channel_versions" in restored_checkpoint
    assert "channel_values" in restored_checkpoint
    assert restored_checkpoint["channel_versions"] == {"messages": 3}
    assert restored_checkpoint["channel_values"] == {"messages": ["before"]}
    assert restored_metadata == {"source": "input"}
    assert new_versions == {"messages": 3}
    assert checkpointer.aput_writes.await_args_list == [
        call(
            {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
            [("messages", {"content": "first"}), ("status", "done")],
            task_id="task-a",
        ),
        call(
            {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}},
            [("events", {"type": "tool"})],
            task_id="task-b",
        ),
    ]


@pytest.mark.anyio
async def test_rollback_restored_checkpoint_becomes_latest_with_real_checkpointer():
    checkpointer = InMemorySaver()
    thread_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    before_checkpoint = _make_checkpoint("0001", ["before"], 1)
    before_config = checkpointer.put(thread_config, before_checkpoint, {"step": 1}, {"messages": 1})
    after_checkpoint = _make_checkpoint("0002", ["after"], 2)
    after_config = checkpointer.put(before_config, after_checkpoint, {"step": 2}, {"messages": 2})
    checkpointer.put_writes(after_config, [("messages", "pending-after")], task_id="task-after")
    await _rollback_to_pre_run_checkpoint(
        checkpointer=checkpointer,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="0001",
        pre_run_snapshot={
            "checkpoint_ns": "",
            "checkpoint": before_checkpoint,
            "metadata": {"step": 1},
            "pending_writes": [("task-before", "messages", "pending-before")],
        },
        snapshot_capture_failed=False,
    )
    latest = checkpointer.get_tuple(thread_config)
    assert latest is not None
    assert latest.config["configurable"]["checkpoint_id"] != "0001"
    assert latest.config["configurable"]["checkpoint_id"] != "0002"
    assert latest.checkpoint["channel_values"] == {"messages": ["before"]}
    assert latest.pending_writes == [("task-before", "messages", "pending-before")]
    assert ("task-after", "messages", "pending-after") not in latest.pending_writes


@pytest.mark.anyio
async def test_rollback_deletes_thread_when_no_snapshot_exists():
    checkpointer = FakeCheckpointer(put_result=None)
    await _rollback_to_pre_run_checkpoint(
        checkpointer=checkpointer,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id=None,
        pre_run_snapshot=None,
        snapshot_capture_failed=False,
    )
    checkpointer.adelete_thread.assert_awaited_once_with("thread-1")
    checkpointer.aput.assert_not_awaited()
    checkpointer.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_raises_when_restore_config_has_no_checkpoint_id():
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}})
    with pytest.raises(RuntimeError, match="did not return checkpoint_id"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=checkpointer,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot={
                "checkpoint_ns": "",
                "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
                "metadata": {},
                "pending_writes": [("task-a", "messages", "value")],
            },
            snapshot_capture_failed=False,
        )
    checkpointer.adelete_thread.assert_not_awaited()
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_normalizes_none_checkpoint_ns_to_root_namespace():
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    await _rollback_to_pre_run_checkpoint(
        checkpointer=checkpointer,
        thread_id="thread-1",
        run_id="run-1",
        pre_run_checkpoint_id="ckpt-1",
        pre_run_snapshot={
            "checkpoint_ns": None,
            "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
            "metadata": {},
            "pending_writes": [],
        },
        snapshot_capture_failed=False,
    )
    checkpointer.aput.assert_awaited_once()
    restore_config, restored_checkpoint, restored_metadata, new_versions = checkpointer.aput.await_args.args
    assert restore_config == {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}}
    assert restored_checkpoint["id"] != "ckpt-1"
    assert restored_checkpoint["channel_versions"] == {}
    assert restored_metadata == {}
    assert new_versions == {}


@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_not_a_tuple():
    """pending_writes containing a non-3-tuple item should raise RuntimeError."""
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    with pytest.raises(RuntimeError, match="rollback failed: pending_write is not a 3-tuple"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=checkpointer,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot={
                "checkpoint_ns": "",
                "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
                "metadata": {},
                "pending_writes": [
                    ("task-a", "messages", "valid"),  # valid
                    ["only", "two"],  # malformed: only 2 elements
                ],
            },
            snapshot_capture_failed=False,
        )
    # aput succeeded but aput_writes should not be called due to malformed data
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_raises_on_malformed_pending_write_non_string_channel():
    """pending_writes containing a non-string channel should raise RuntimeError."""
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    with pytest.raises(RuntimeError, match="rollback failed: pending_write has non-string channel"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=checkpointer,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot={
                "checkpoint_ns": "",
                "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
                "metadata": {},
                "pending_writes": [
                    ("task-a", 123, "value"),  # malformed: channel is not a string
                ],
            },
            snapshot_capture_failed=False,
        )
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_not_awaited()


@pytest.mark.anyio
async def test_rollback_propagates_aput_writes_failure():
    """If aput_writes fails, the exception should propagate (not be swallowed)."""
    checkpointer = FakeCheckpointer(put_result={"configurable": {"thread_id": "thread-1", "checkpoint_ns": "", "checkpoint_id": "restored-1"}})
    # Simulate aput_writes failure
    checkpointer.aput_writes.side_effect = RuntimeError("Database connection lost")
    with pytest.raises(RuntimeError, match="Database connection lost"):
        await _rollback_to_pre_run_checkpoint(
            checkpointer=checkpointer,
            thread_id="thread-1",
            run_id="run-1",
            pre_run_checkpoint_id="ckpt-1",
            pre_run_snapshot={
                "checkpoint_ns": "",
                "checkpoint": {"id": "ckpt-1", "channel_versions": {}},
                "metadata": {},
                "pending_writes": [
                    ("task-a", "messages", "value"),
                ],
            },
            snapshot_capture_failed=False,
        )
    # aput succeeded, aput_writes was called but failed
    checkpointer.aput.assert_awaited_once()
    checkpointer.aput_writes.assert_awaited_once()


def test_agent_factory_supports_app_config_detects_supported_signature():
    def factory(*, config, app_config=None):
        return (config, app_config)

    assert _agent_factory_supports_app_config(factory) is True


def test_build_runtime_context_defaults_to_thread_and_run_id():
    ctx = _build_runtime_context("thread-1", "run-1", None)
    assert ctx == {"thread_id": "thread-1", "run_id": "run-1"}


def test_build_runtime_context_merges_caller_context():
    """Regression for issue #2677: keys from ``config['context']`` (e.g. ``agent_name``)
    must be merged into the Runtime's context so that ``ToolRuntime.context`` — which
    is what ``setup_agent`` reads — can see them."""
    caller_context = {"agent_name": "my-agent", "is_bootstrap": True, "model_name": "gpt-4"}
    ctx = _build_runtime_context("thread-1", "run-1", caller_context)
    assert ctx["thread_id"] == "thread-1"
    assert ctx["run_id"] == "run-1"
    assert ctx["agent_name"] == "my-agent"
    assert ctx["is_bootstrap"] is True
    assert ctx["model_name"] == "gpt-4"


def test_build_runtime_context_caller_cannot_override_thread_id_or_run_id():
    """A malicious or buggy caller must not be able to overwrite the worker-assigned
    ``thread_id`` / ``run_id`` by stuffing them into ``config['context']``."""
    caller_context = {"thread_id": "spoofed", "run_id": "spoofed", "agent_name": "ok"}
    ctx = _build_runtime_context("real-thread", "real-run", caller_context)
    assert ctx["thread_id"] == "real-thread"
    assert ctx["run_id"] == "real-run"
    assert ctx["agent_name"] == "ok"


def test_build_runtime_context_ignores_non_dict_caller_context():
    ctx = _build_runtime_context("thread-1", "run-1", "not-a-dict")
    assert ctx == {"thread_id": "thread-1", "run_id": "run-1"}


def test_agent_factory_supports_app_config_returns_false_when_signature_lookup_fails(monkeypatch):
    class BrokenCallable:
        def __call__(self, **kwargs):
            return kwargs

    monkeypatch.setattr("deerflow.runtime.runs.worker.inspect.signature", lambda _obj: (_ for _ in ()).throw(ValueError("boom")))
    assert _agent_factory_supports_app_config(BrokenCallable()) is False