Files
deer-flow/backend/tests/test_run_manager.py
T
AochenShen99 66d6a6a4e8 fix: harden run finalization persistence (#3155)
* fix: harden run finalization persistence

* style: format gateway recovery test

* fix: align run repository return types

* fix: harden completion recovery follow-up
2026-05-23 00:09:06 +08:00

752 lines
26 KiB
Python

"""Tests for RunManager."""
import logging
import re
import sqlite3
from typing import Any
import pytest
from sqlalchemy.exc import DatabaseError as SQLAlchemyDatabaseError
from deerflow.runtime import DisconnectMode, RunManager, RunStatus
from deerflow.runtime.runs.manager import PersistenceRetryPolicy
from deerflow.runtime.runs.store.memory import MemoryRunStore
# Matches the leading ``YYYY-MM-DDTHH:MM:SS`` portion of an ISO-8601 style
# timestamp. The pattern is anchored at the start only, so anything following
# the seconds (fractional part, timezone offset) is not inspected.
ISO_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}")
@pytest.fixture
def manager() -> RunManager:
    """Provide a fresh ``RunManager`` constructed with default arguments."""
    fresh_manager = RunManager()
    return fresh_manager
class FlakyStatusRunStore(MemoryRunStore):
    """In-memory store whose first ``status_failures`` status writes raise.

    Every injected failure is ``sqlite3.OperationalError("database is locked")``.
    Once the failure budget is used up, writes are delegated to
    ``MemoryRunStore.update_status``.
    """

    def __init__(self, *, status_failures: int) -> None:
        super().__init__()
        # Number of upcoming update_status calls that must still fail.
        self.status_failures = status_failures
        # Every update_status call is counted, whether it fails or not.
        self.status_update_attempts = 0

    async def update_status(self, run_id, status, *, error=None):
        """Record the attempt, then fail or delegate depending on the budget."""
        self.status_update_attempts += 1
        if self.status_failures <= 0:
            return await super().update_status(run_id, status, error=error)
        self.status_failures -= 1
        raise sqlite3.OperationalError("database is locked")
class MissingRowStatusRunStore(MemoryRunStore):
    """In-memory store whose status updates always report a missing row."""

    async def update_status(self, run_id, status, *, error=None):
        """Forward the update to the parent store, then return ``False``."""
        # The parent's own result is discarded on purpose; the return value
        # is always False.
        await super().update_status(run_id, status, error=error)
        return False
class PermanentStatusRunStore(MemoryRunStore):
    """In-memory store whose status writes always raise a SQLAlchemy ``DatabaseError``."""

    def __init__(self) -> None:
        super().__init__()
        # Counts how many times update_status has been invoked.
        self.status_update_attempts = 0

    async def update_status(self, run_id, status, *, error=None):
        """Record the attempt and raise a wrapped ``sqlite3.DatabaseError``."""
        self.status_update_attempts += 1
        statement = "UPDATE runs SET status = :status WHERE run_id = :run_id"
        params = {"status": status, "run_id": run_id}
        driver_error = sqlite3.DatabaseError("no such table: runs")
        raise SQLAlchemyDatabaseError(statement, params, driver_error)
class FailingStatusRunStore(MemoryRunStore):
    """In-memory store whose status updates raise on every call."""

    def __init__(self) -> None:
        super().__init__()
        # Counts how many times update_status has been invoked.
        self.status_update_attempts = 0

    async def update_status(self, run_id, status, *, error=None):
        """Record the attempt and raise ``sqlite3.OperationalError``."""
        self.status_update_attempts += 1
        lock_error = sqlite3.OperationalError("database is locked")
        raise lock_error
class MissingCompletionRunStore(MemoryRunStore):
    """In-memory store whose first completion update reports a missing row."""

    def __init__(self) -> None:
        super().__init__()
        # Counts how many times update_run_completion has been invoked.
        self.completion_update_attempts = 0

    async def update_run_completion(self, run_id, *, status, **kwargs):
        """Return ``False`` on the first call; delegate to the parent afterwards."""
        self.completion_update_attempts += 1
        is_first_attempt = self.completion_update_attempts == 1
        if not is_first_attempt:
            return await super().update_run_completion(run_id, status=status, **kwargs)
        return False
class AlwaysMissingCompletionRunStore(MemoryRunStore):
    """In-memory store whose completion updates report a missing row every time."""

    def __init__(self) -> None:
        super().__init__()
        # Counts how many times update_run_completion has been invoked.
        self.completion_update_attempts = 0

    async def update_run_completion(self, run_id, *, status, **kwargs):
        """Record the attempt and return ``False`` without touching the parent store."""
        self.completion_update_attempts = self.completion_update_attempts + 1
        return False
async def _stored_statuses(store: MemoryRunStore, *run_ids: str) -> dict[str, Any]:
    """Map each run id to its stored ``status``.

    A run id maps to ``None`` when ``store.get`` returns a falsy row.
    """
    statuses: dict[str, Any] = {}
    for run_id in run_ids:
        stored = await store.get(run_id)
        if stored:
            statuses[run_id] = stored["status"]
        else:
            statuses[run_id] = None
    return statuses
@pytest.mark.anyio
async def test_create_and_get(manager: RunManager):
    """A run created with explicit options exposes them and is returned by ``get``."""
    created = await manager.create(
        "thread-1",
        "lead_agent",
        metadata={"key": "val"},
        kwargs={"input": {}},
        multitask_strategy="reject",
    )

    # Identity and configuration passed to create().
    assert created.thread_id == "thread-1"
    assert created.assistant_id == "lead_agent"
    assert created.metadata == {"key": "val"}
    assert created.kwargs == {"input": {}}
    assert created.multitask_strategy == "reject"

    # Lifecycle fields populated by the manager.
    assert created.status == RunStatus.pending
    assert ISO_RE.match(created.created_at)
    assert ISO_RE.match(created.updated_at)

    # get() must hand back the very same in-memory object.
    looked_up = await manager.get(created.run_id)
    assert looked_up is created
@pytest.mark.anyio
async def test_status_transitions(manager: RunManager):
    """A run moves pending -> running -> success as ``set_status`` is called."""
    run = await manager.create("thread-1")
    run_id = run.run_id
    assert run.status == RunStatus.pending

    await manager.set_status(run_id, RunStatus.running)
    assert run.status == RunStatus.running
    assert ISO_RE.match(run.updated_at)

    await manager.set_status(run_id, RunStatus.success)
    assert run.status == RunStatus.success
@pytest.mark.anyio
async def test_cancel(manager: RunManager):
    """Cancelling a running run sets its abort event and marks it interrupted."""
    run = await manager.create("thread-1")
    await manager.set_status(run.run_id, RunStatus.running)

    was_cancelled = await manager.cancel(run.run_id)

    assert was_cancelled is True
    assert run.abort_event.is_set()
    assert run.status == RunStatus.interrupted
@pytest.mark.anyio
async def test_cancel_persists_interrupted_status_to_store():
    """After ``cancel``, the backing store row carries the ``interrupted`` status."""
    backing_store = MemoryRunStore()
    run_manager = RunManager(store=backing_store)
    run = await run_manager.create("thread-1")
    await run_manager.set_status(run.run_id, RunStatus.running)

    was_cancelled = await run_manager.cancel(run.run_id)
    persisted = await backing_store.get(run.run_id)

    assert was_cancelled is True
    assert persisted is not None
    assert persisted["status"] == "interrupted"
@pytest.mark.anyio
async def test_status_persistence_retries_transient_sqlite_lock():
    """Two injected "database is locked" failures must not leave the final status stale."""
    flaky_store = FlakyStatusRunStore(status_failures=2)
    run_manager = RunManager(store=flaky_store)
    run = await run_manager.create("thread-1")

    await run_manager.set_status(run.run_id, RunStatus.running)
    await run_manager.set_status(run.run_id, RunStatus.success)

    persisted = await flaky_store.get(run.run_id)
    assert persisted is not None
    assert persisted["status"] == "success"
    # Two failed writes plus one successful write per transition.
    assert flaky_store.status_update_attempts >= 4
@pytest.mark.anyio
async def test_status_persistence_recreates_missing_store_row():
    """Setting a final status restores the store row when it was deleted beforehand."""
    lossy_store = MissingRowStatusRunStore()
    run_manager = RunManager(store=lossy_store)
    run = await run_manager.create("thread-1")

    # Drop the persisted row so the subsequent status write finds nothing.
    await lossy_store.delete(run.run_id)
    await run_manager.set_status(run.run_id, RunStatus.error, error="boom")

    persisted = await lossy_store.get(run.run_id)
    assert persisted is not None
    assert persisted["status"] == "error"
    assert persisted["error"] == "boom"
@pytest.mark.anyio
async def test_status_persistence_does_not_retry_permanent_sqlalchemy_errors():
    """A permanent SQLAlchemy failure results in exactly one status-write attempt."""
    broken_store = PermanentStatusRunStore()
    # Allow up to five attempts so a single attempt proves no retry happened.
    retry_policy = PersistenceRetryPolicy(max_attempts=5, initial_delay=0)
    run_manager = RunManager(store=broken_store, persistence_retry_policy=retry_policy)
    run = await run_manager.create("thread-1")

    await run_manager.set_status(run.run_id, RunStatus.error, error="boom")

    assert broken_store.status_update_attempts == 1
@pytest.mark.anyio
async def test_completion_persistence_recreates_missing_store_row():
    """A completion update restores a deleted row and stores the final counters."""
    lossy_store = MissingCompletionRunStore()
    run_manager = RunManager(store=lossy_store)
    run = await run_manager.create("thread-1")
    for next_status in (RunStatus.running, RunStatus.success):
        await run_manager.set_status(run.run_id, next_status)

    # Remove the persisted row before the completion update is issued.
    await lossy_store.delete(run.run_id)
    await run_manager.update_run_completion(
        run.run_id,
        status="success",
        total_tokens=42,
        llm_call_count=2,
        last_ai_message="done",
    )

    persisted = await lossy_store.get(run.run_id)
    assert persisted is not None
    assert persisted["status"] == "success"
    assert persisted["total_tokens"] == 42
    assert persisted["llm_call_count"] == 2
    assert persisted["last_ai_message"] == "done"
    # First attempt reports a missing row, second one goes through.
    assert lossy_store.completion_update_attempts == 2
@pytest.mark.anyio
async def test_completion_persistence_warns_when_recreated_row_still_missing(caplog):
    """A second zero-row completion update must produce a warning in the log."""
    missing_store = AlwaysMissingCompletionRunStore()
    run_manager = RunManager(store=missing_store)
    run = await run_manager.create("thread-1")
    await run_manager.set_status(run.run_id, RunStatus.success)

    # Capture warnings emitted by the manager module from here on.
    caplog.set_level(logging.WARNING, logger="deerflow.runtime.runs.manager")
    await run_manager.update_run_completion(run.run_id, status="success", total_tokens=42)

    assert missing_store.completion_update_attempts == 2
    assert "affected no rows after row recreation" in caplog.text
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_marks_stale_rows_error():
    """Persisted pending/running rows become ``error``; a finished row is left alone."""
    backing_store = MemoryRunStore()
    await backing_store.put("pending-run", thread_id="thread-1", status="pending", created_at="2026-01-01T00:00:00+00:00")
    await backing_store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:01+00:00")
    await backing_store.put("success-run", thread_id="thread-1", status="success", created_at="2026-01-01T00:00:02+00:00")
    run_manager = RunManager(store=backing_store)

    recovered = await run_manager.reconcile_orphaned_inflight_runs(
        error="Gateway restarted before this run reached a durable final state.",
        before="2026-01-01T00:00:02+00:00",
    )

    recovered_ids = {record.run_id for record in recovered}
    assert recovered_ids == {"pending-run", "running-run"}

    expected_statuses = {
        "pending-run": "error",
        "running-run": "error",
        "success-run": "success",
    }
    assert await _stored_statuses(backing_store, "pending-run", "running-run", "success-run") == expected_statuses
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_skips_live_local_run():
    """Startup recovery should not mark an active row orphaned when this worker owns it.

    The run is created through the manager, so it is present in the
    manager's in-memory state while reconciliation executes.
    """
    store = MemoryRunStore()
    manager = RunManager(store=store)
    record = await manager.create("thread-1")
    await manager.set_status(record.run_id, RunStatus.running)

    recovered = await manager.reconcile_orphaned_inflight_runs(
        error="Gateway restarted before this run reached a durable final state.",
    )
    stored = await store.get(record.run_id)

    assert recovered == []
    # Guard the subscript below so a missing row fails as a clear assertion
    # rather than a TypeError, matching the other store-backed tests.
    assert stored is not None
    assert stored["status"] == "running"
@pytest.mark.anyio
async def test_reconcile_orphaned_inflight_runs_skips_rows_when_error_status_is_not_persisted():
    """Startup recovery must not report a row as recovered if the error update failed.

    ``FailingStatusRunStore`` raises on every status write, so with a retry
    policy of two attempts the row must stay ``running`` and exactly two
    write attempts must have been made.
    """
    store = FailingStatusRunStore()
    await store.put("running-run", thread_id="thread-1", status="running", created_at="2026-01-01T00:00:00+00:00")
    manager = RunManager(
        store=store,
        persistence_retry_policy=PersistenceRetryPolicy(max_attempts=2, initial_delay=0),
    )

    recovered = await manager.reconcile_orphaned_inflight_runs(
        error="Gateway restarted before this run reached a durable final state.",
        before="2026-01-01T00:00:01+00:00",
    )
    stored = await store.get("running-run")

    assert recovered == []
    # Guard the subscript below so a missing row fails as a clear assertion
    # rather than a TypeError, matching the other store-backed tests.
    assert stored is not None
    assert stored["status"] == "running"
    assert store.status_update_attempts == 2
@pytest.mark.anyio
async def test_cancel_not_inflight(manager: RunManager):
    """``cancel`` returns ``False`` for a run that already reached ``success``."""
    finished = await manager.create("thread-1")
    await manager.set_status(finished.run_id, RunStatus.success)

    was_cancelled = await manager.cancel(finished.run_id)

    assert was_cancelled is False
@pytest.mark.anyio
async def test_list_by_thread(manager: RunManager):
    """Only the requested thread's runs are listed, newest first."""
    older = await manager.create("thread-1")
    newer = await manager.create("thread-1")
    # A run on another thread must not leak into the listing.
    await manager.create("thread-2")

    listed = await manager.list_by_thread("thread-1")

    assert len(listed) == 2
    # Newest first: `newer` was created after `older`.
    first_listed, second_listed = listed[0], listed[1]
    assert first_listed.run_id == newer.run_id
    assert second_listed.run_id == older.run_id
@pytest.mark.anyio
async def test_list_by_thread_is_stable_when_timestamps_tie(manager: RunManager, monkeypatch: pytest.MonkeyPatch):
    """With identical timestamps, runs are listed in insertion order."""
    # Freeze the manager's clock so both runs get the same timestamp.
    monkeypatch.setattr("deerflow.runtime.runs.manager._now_iso", lambda: "2026-01-01T00:00:00+00:00")
    first = await manager.create("thread-1")
    second = await manager.create("thread-1")

    listed = await manager.list_by_thread("thread-1")

    listed_ids = [run.run_id for run in listed]
    assert listed_ids == [first.run_id, second.run_id]
@pytest.mark.anyio
async def test_has_inflight(manager: RunManager):
    """``has_inflight`` is True for a freshly created run and False once it succeeds."""
    run = await manager.create("thread-1")
    inflight_before = await manager.has_inflight("thread-1")
    assert inflight_before is True

    await manager.set_status(run.run_id, RunStatus.success)
    inflight_after = await manager.has_inflight("thread-1")
    assert inflight_after is False
@pytest.mark.anyio
async def test_cleanup(manager: RunManager):
    """A run is no longer retrievable after ``cleanup`` with zero delay."""
    created = await manager.create("thread-1")
    target_id = created.run_id

    await manager.cleanup(target_id, delay=0)

    remaining = await manager.get(target_id)
    assert remaining is None
@pytest.mark.anyio
async def test_set_status_with_error(manager: RunManager):
    """``set_status`` stores both the error status and its message on the record."""
    run = await manager.create("thread-1")

    await manager.set_status(run.run_id, RunStatus.error, error="Something went wrong")

    assert run.error == "Something went wrong"
    assert run.status == RunStatus.error
@pytest.mark.anyio
async def test_get_nonexistent(manager: RunManager):
    """``get`` yields ``None`` for an unknown run id."""
    missing = await manager.get("does-not-exist")
    assert missing is None
@pytest.mark.anyio
async def test_get_hydrates_store_only_run():
    """A run that exists only in the store is hydrated into a record by ``get``."""
    backing_store = MemoryRunStore()
    await backing_store.put(
        "run-store-only",
        thread_id="thread-1",
        assistant_id="lead_agent",
        status="success",
        multitask_strategy="reject",
        metadata={"source": "store"},
        kwargs={"input": "value"},
        created_at="2026-01-01T00:00:00+00:00",
        model_name="model-a",
    )
    # The manager never created this run, so it has no in-memory copy.
    run_manager = RunManager(store=backing_store)

    hydrated = await run_manager.get("run-store-only")

    assert hydrated is not None
    # Fields copied from the stored row.
    assert hydrated.run_id == "run-store-only"
    assert hydrated.thread_id == "thread-1"
    assert hydrated.assistant_id == "lead_agent"
    assert hydrated.status == RunStatus.success
    assert hydrated.on_disconnect == DisconnectMode.cancel
    assert hydrated.metadata == {"source": "store"}
    assert hydrated.kwargs == {"input": "value"}
    assert hydrated.model_name == "model-a"
    # No live task is attached, and the record is flagged as store-only.
    assert hydrated.task is None
    assert hydrated.store_only is True
@pytest.mark.anyio
async def test_get_hydrates_run_with_null_enum_fields():
    """A stored row with ``status=None`` hydrates with default enum values instead of raising."""
    backing_store = MemoryRunStore()
    # Mimic a SQL row whose nullable status column holds NULL.
    await backing_store.put(
        "run-null-status",
        thread_id="thread-1",
        status=None,
        created_at="2026-01-01T00:00:00+00:00",
    )
    run_manager = RunManager(store=backing_store)

    hydrated = await run_manager.get("run-null-status")

    assert hydrated is not None
    assert hydrated.status == RunStatus.pending
    assert hydrated.on_disconnect == DisconnectMode.cancel
    assert hydrated.store_only is True
@pytest.mark.anyio
async def test_list_by_thread_hydrates_run_with_null_enum_fields():
    """``list_by_thread`` keeps a row with ``status=None`` and fills in default enum values."""
    backing_store = MemoryRunStore()
    await backing_store.put(
        "run-null-status-list",
        thread_id="thread-null",
        status=None,
        created_at="2026-01-01T00:00:00+00:00",
    )
    run_manager = RunManager(store=backing_store)

    listed = await run_manager.list_by_thread("thread-null")

    assert len(listed) == 1
    only_run = listed[0]
    assert only_run.run_id == "run-null-status-list"
    assert only_run.status == RunStatus.pending
    assert only_run.on_disconnect == DisconnectMode.cancel
@pytest.mark.anyio
async def test_create_record_is_not_store_only(manager: RunManager):
    """Records returned by ``create`` carry ``store_only=False``."""
    created = await manager.create("thread-1")
    assert created.store_only is False
@pytest.mark.anyio
async def test_get_prefers_in_memory_record_over_store():
    """``get`` returns the in-memory record even when the store row has diverged."""
    backing_store = MemoryRunStore()
    run_manager = RunManager(store=backing_store)
    live_record = await run_manager.create("thread-1")

    # Change the stored row behind the manager's back.
    await backing_store.update_status(live_record.run_id, "success")

    looked_up = await run_manager.get(live_record.run_id)
    assert looked_up is live_record
    # The in-memory status is untouched by the direct store write.
    assert looked_up.status == RunStatus.pending
@pytest.mark.anyio
async def test_list_by_thread_merges_store_runs_newest_first():
    """Listing combines the in-memory run with the thread's older store-only run."""
    backing_store = MemoryRunStore()
    await backing_store.put("old-store", thread_id="thread-1", status="success", created_at="2026-01-01T00:00:00+00:00")
    # A row on a different thread that must be filtered out.
    await backing_store.put("other-thread", thread_id="thread-2", status="success", created_at="2026-01-03T00:00:00+00:00")
    run_manager = RunManager(store=backing_store)
    live_record = await run_manager.create("thread-1")

    listed = await run_manager.list_by_thread("thread-1")

    listed_ids = [run.run_id for run in listed]
    assert listed_ids == [live_record.run_id, "old-store"]
    assert listed[0] is live_record
@pytest.mark.anyio
async def test_create_defaults(manager: RunManager):
    """``create`` with only a thread id falls back to the default field values."""
    created = await manager.create("thread-1")

    assert created.assistant_id is None
    assert created.multitask_strategy == "reject"
    assert created.metadata == {}
    assert created.kwargs == {}
@pytest.mark.anyio
async def test_model_name_create_or_reject():
    """``create_or_reject`` keeps ``model_name`` on the record and in the store."""
    from deerflow.runtime.runs.schemas import DisconnectMode

    model_name = "anthropic.claude-sonnet-4-20250514-v1:0"
    backing_store = MemoryRunStore()
    run_manager = RunManager(store=backing_store)

    created = await run_manager.create_or_reject(
        "thread-1",
        assistant_id="lead_agent",
        on_disconnect=DisconnectMode.cancel,
        metadata={"key": "val"},
        kwargs={"input": {}},
        multitask_strategy="reject",
        model_name=model_name,
    )
    assert created.model_name == model_name
    assert created.status == RunStatus.pending

    # The value must have reached the backing store.
    persisted = await backing_store.get(created.run_id)
    assert persisted is not None
    assert persisted["model_name"] == model_name

    # Reading back through the manager exposes it as well.
    looked_up = await run_manager.get(created.run_id)
    assert looked_up is not None
    assert looked_up.model_name == model_name
@pytest.mark.anyio
async def test_create_or_reject_interrupt_persists_interrupted_status_to_store():
    """With the ``interrupt`` strategy, the previous run is stored as ``interrupted``."""
    backing_store = MemoryRunStore()
    run_manager = RunManager(store=backing_store)
    previous = await run_manager.create("thread-1")
    await run_manager.set_status(previous.run_id, RunStatus.running)

    replacement = await run_manager.create_or_reject("thread-1", multitask_strategy="interrupt")
    persisted_previous = await backing_store.get(previous.run_id)

    assert replacement.run_id != previous.run_id
    assert previous.status == RunStatus.interrupted
    assert persisted_previous is not None
    assert persisted_previous["status"] == "interrupted"
@pytest.mark.anyio
async def test_create_or_reject_rollback_persists_interrupted_status_to_store():
    """With the ``rollback`` strategy, the previous run is stored as ``interrupted``."""
    backing_store = MemoryRunStore()
    run_manager = RunManager(store=backing_store)
    previous = await run_manager.create("thread-1")
    await run_manager.set_status(previous.run_id, RunStatus.running)

    replacement = await run_manager.create_or_reject("thread-1", multitask_strategy="rollback")
    persisted_previous = await backing_store.get(previous.run_id)

    assert replacement.run_id != previous.run_id
    assert previous.status == RunStatus.interrupted
    assert persisted_previous is not None
    assert persisted_previous["status"] == "interrupted"
@pytest.mark.anyio
async def test_model_name_default_is_none():
    """create_or_reject with ``model_name=None`` should keep ``None`` on the record and in the store."""
    from deerflow.runtime.runs.schemas import DisconnectMode

    store = MemoryRunStore()
    mgr = RunManager(store=store)

    record = await mgr.create_or_reject(
        "thread-1",
        on_disconnect=DisconnectMode.cancel,
        model_name=None,
    )
    assert record.model_name is None

    stored = await store.get(record.run_id)
    # Guard the subscript below so a missing row fails as a clear assertion
    # rather than a TypeError, matching the other store-backed tests.
    assert stored is not None
    assert stored["model_name"] is None
# ---------------------------------------------------------------------------
# Store fallback tests (simulates gateway restart scenario)
# ---------------------------------------------------------------------------
@pytest.fixture
def manager_with_store() -> RunManager:
    """Provide a ``RunManager`` wired to a fresh ``MemoryRunStore``."""
    backing_store = MemoryRunStore()
    return RunManager(store=backing_store)
@pytest.mark.anyio
async def test_list_by_thread_returns_store_records_after_restart(manager_with_store: RunManager):
    """Runs remain listable from the store once the in-memory state is wiped.

    Clearing ``_runs`` stands in for a process restart.
    """
    mgr = manager_with_store
    succeeded = await mgr.create("thread-1", "agent-1")
    await mgr.set_status(succeeded.run_id, RunStatus.success)
    failed = await mgr.create("thread-1", "agent-2")
    await mgr.set_status(failed.run_id, RunStatus.error, error="boom")

    # Drop the in-memory records to simulate a restart.
    mgr._runs.clear()

    listed = await mgr.list_by_thread("thread-1")
    assert len(listed) == 2

    status_by_id = {run.run_id: run.status for run in listed}
    assert status_by_id[succeeded.run_id] == RunStatus.success
    assert status_by_id[failed.run_id] == RunStatus.error

    # The remaining fields must survive the round-trip too.
    for run in listed:
        assert run.thread_id == "thread-1"
        assert ISO_RE.match(run.created_at)
@pytest.mark.anyio
async def test_list_by_thread_merges_in_memory_and_store(manager_with_store: RunManager):
    """A listing contains both the store-only run and the live in-memory run."""
    mgr = manager_with_store

    # This run completes and therefore exists in memory and in the store.
    finished = await mgr.create("thread-1")
    await mgr.set_status(finished.run_id, RunStatus.success)

    # Simulate a restart, then start a new run that lives in memory.
    mgr._runs.clear()
    live = await mgr.create("thread-1")

    listed = await mgr.list_by_thread("thread-1")
    assert len(listed) == 2

    listed_ids = {run.run_id for run in listed}
    assert finished.run_id in listed_ids
    assert live.run_id in listed_ids

    # The live run must be the very same object, not a hydrated copy.
    live_matches = (run for run in listed if run.run_id == live.run_id)
    assert next(live_matches) is live
@pytest.mark.anyio
async def test_list_by_thread_no_store():
    """A store-less manager lists nothing once its in-memory runs are cleared."""
    storeless_manager = RunManager()
    await storeless_manager.create("thread-1")
    storeless_manager._runs.clear()

    listed = await storeless_manager.list_by_thread("thread-1")

    assert listed == []
@pytest.mark.anyio
async def test_aget_returns_in_memory_record(manager_with_store: RunManager):
    """``aget`` hands back the in-memory record object when one exists."""
    created = await manager_with_store.create("thread-1", "agent-1")

    looked_up = await manager_with_store.aget(created.run_id)

    # Identity, not just equality.
    assert looked_up is created
@pytest.mark.anyio
async def test_aget_falls_back_to_store(manager_with_store: RunManager):
    """``aget`` rebuilds a record from the store when memory no longer holds it."""
    mgr = manager_with_store
    original = await mgr.create("thread-1", "agent-1")
    await mgr.set_status(original.run_id, RunStatus.success)
    # Forget the in-memory copy so only the store can answer.
    mgr._runs.clear()

    restored = await mgr.aget(original.run_id)

    assert restored is not None
    assert restored.run_id == original.run_id
    assert restored.thread_id == "thread-1"
    assert restored.assistant_id == "agent-1"
    assert restored.status == RunStatus.success
@pytest.mark.anyio
async def test_aget_falls_back_to_store_with_user_filter():
    """``aget`` returns a store-only run to its owner and ``None`` to another user."""
    backing_store = MemoryRunStore()
    await backing_store.put("run-1", thread_id="thread-1", user_id="user-1", status="success")
    mgr = RunManager(store=backing_store)

    for_owner = await mgr.aget("run-1", user_id="user-1")
    for_stranger = await mgr.aget("run-1", user_id="user-2")

    assert for_owner is not None
    assert for_stranger is None
@pytest.mark.anyio
async def test_aget_returns_none_for_unknown(manager_with_store: RunManager):
    """``aget`` yields ``None`` when neither memory nor the store knows the run id."""
    missing = await manager_with_store.aget("nonexistent-run-id")
    assert missing is None
@pytest.mark.anyio
async def test_aget_store_failure_is_graceful():
    """``aget`` returns ``None`` rather than propagating an exception from ``store.get``."""
    from unittest.mock import AsyncMock

    failing_store = MemoryRunStore()
    # Every store lookup now raises.
    failing_store.get = AsyncMock(side_effect=RuntimeError("db down"))
    mgr = RunManager(store=failing_store)

    looked_up = await mgr.aget("some-id")

    assert looked_up is None
@pytest.mark.anyio
async def test_list_by_thread_store_failure_is_graceful():
    """When ``store.list_by_thread`` raises, the listing still contains the in-memory run."""
    from unittest.mock import AsyncMock

    failing_store = MemoryRunStore()
    # Only the store's listing call is broken; other store methods still work.
    failing_store.list_by_thread = AsyncMock(side_effect=RuntimeError("db down"))
    mgr = RunManager(store=failing_store)
    live = await mgr.create("thread-1")

    listed = await mgr.list_by_thread("thread-1")

    assert len(listed) == 1
    assert listed[0].run_id == live.run_id
@pytest.mark.anyio
async def test_list_by_thread_falls_back_to_store_with_user_filter():
    """Store-backed listings contain only the rows owned by the requesting user."""
    backing_store = MemoryRunStore()
    # Two runs on the same thread, owned by different users.
    await backing_store.put("run-1", thread_id="thread-1", user_id="user-1", status="success")
    await backing_store.put("run-2", thread_id="thread-1", user_id="user-2", status="success")
    mgr = RunManager(store=backing_store)

    listed = await mgr.list_by_thread("thread-1", user_id="user-1")

    listed_ids = [run.run_id for run in listed]
    assert listed_ids == ["run-1"]