mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
66d6a6a4e8
* fix: harden run finalization persistence * style: format gateway recovery test * fix: align run repository return types * fix: harden completion recovery follow-up
128 lines
5.6 KiB
Python
128 lines
5.6 KiB
Python
"""Gateway startup recovery for stale persisted runs."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from contextlib import asynccontextmanager
|
|
from types import SimpleNamespace
|
|
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
|
|
import deerflow.runtime as runtime_module
|
|
from app.gateway import deps as gateway_deps
|
|
from deerflow.persistence import engine as engine_module
|
|
from deerflow.persistence import thread_meta as thread_meta_module
|
|
from deerflow.runtime.checkpointer import async_provider as checkpointer_module
|
|
from deerflow.runtime.events import store as event_store_module
|
|
|
|
|
|
@asynccontextmanager
async def _fake_context(value):
    """Stand-in async context manager: hand back *value* on entry, do nothing on exit."""
    resource = value
    yield resource
|
|
|
|
|
|
class _FakeRunManager:
    """RunManager double that records startup reconciliation calls.

    Behaviour is configured through class attributes so a test can set it up before
    the gateway constructs the manager:

    - ``recovered_runs``: returned by ``reconcile_orphaned_inflight_runs``.
    - ``latest_by_thread``: per-thread run lists returned by ``list_by_thread``;
      threads missing from it fall back to the entries of ``recovered_runs`` that
      belong to the requested thread.
    """

    # Every constructed instance, in creation order; tests reset this before each run.
    instances: list[_FakeRunManager] = []
    recovered_runs = [SimpleNamespace(run_id="run-1", thread_id="thread-1")]
    latest_by_thread: dict[str, list[SimpleNamespace]] = {}

    def __init__(self, *, store):
        self.store = store
        self.reconcile_calls: list[dict] = []
        self.list_by_thread_calls: list[dict] = []
        _FakeRunManager.instances.append(self)

    async def reconcile_orphaned_inflight_runs(self, *, error: str, before: str | None = None):
        """Record the call and return the configured ``recovered_runs``."""
        self.reconcile_calls.append({"error": error, "before": before})
        return self.recovered_runs

    async def list_by_thread(self, thread_id: str, *, user_id=None, limit: int = 100):
        """Record the call and return at most ``limit`` runs belonging to ``thread_id``."""
        self.list_by_thread_calls.append({"thread_id": thread_id, "user_id": user_id, "limit": limit})
        runs = self.latest_by_thread.get(thread_id)
        if runs is None:
            # A per-thread query must never surface runs recovered for other threads.
            runs = [run for run in self.recovered_runs if run.thread_id == thread_id]
        return runs[:limit]
|
|
|
|
|
|
class _FakeThreadStore:
    """Thread-metadata store double that only records ``update_status`` calls."""

    def __init__(self) -> None:
        # One (thread_id, status, user_id) tuple per call, in call order.
        self.status_updates: list[tuple[str, str, str | None]] = []

    async def update_status(self, thread_id: str, status: str, *, user_id=None) -> None:
        """Remember the requested status change instead of persisting it."""
        record = (thread_id, status, user_id)
        self.status_updates.append(record)
|
|
|
|
|
|
def _sqlite_config() -> SimpleNamespace:
    """Build the minimal config object these tests pass to ``langgraph_runtime``."""
    return SimpleNamespace(
        database=SimpleNamespace(backend="sqlite"),
        run_events=SimpleNamespace(backend="memory"),
    )


def _install_runtime_fakes(
    monkeypatch: pytest.MonkeyPatch,
    thread_store: _FakeThreadStore,
    *,
    recovered_runs: list[SimpleNamespace],
    latest_by_thread: dict[str, list[SimpleNamespace]],
) -> None:
    """Stub out the engine, runtime, thread-store, event-store and RunManager factories.

    ``_FakeRunManager``'s class-level state is patched through ``monkeypatch`` so it is
    restored when the test finishes instead of leaking into later tests.
    """

    async def fake_init_engine_from_config(_database):
        return None

    async def fake_close_engine():
        return None

    monkeypatch.setattr(_FakeRunManager, "instances", [])
    monkeypatch.setattr(_FakeRunManager, "recovered_runs", recovered_runs)
    monkeypatch.setattr(_FakeRunManager, "latest_by_thread", latest_by_thread)

    monkeypatch.setattr(engine_module, "init_engine_from_config", fake_init_engine_from_config)
    monkeypatch.setattr(engine_module, "get_session_factory", lambda: None)
    monkeypatch.setattr(engine_module, "close_engine", fake_close_engine)
    monkeypatch.setattr(runtime_module, "make_stream_bridge", lambda _config: _fake_context(object()))
    monkeypatch.setattr(checkpointer_module, "make_checkpointer", lambda _config: _fake_context(object()))
    monkeypatch.setattr(runtime_module, "make_store", lambda _config: _fake_context(object()))
    monkeypatch.setattr(thread_meta_module, "make_thread_store", lambda _sf, _store: thread_store)
    monkeypatch.setattr(event_store_module, "make_run_event_store", lambda _config: object())
    monkeypatch.setattr(gateway_deps, "RunManager", _FakeRunManager)


@pytest.mark.anyio
async def test_sqlite_runtime_reconciles_orphaned_runs_on_startup(monkeypatch):
    """SQLite startup should recover stale active runs before serving requests."""
    thread_store = _FakeThreadStore()
    _install_runtime_fakes(
        monkeypatch,
        thread_store,
        recovered_runs=[SimpleNamespace(run_id="run-1", thread_id="thread-1")],
        latest_by_thread={},
    )

    async with gateway_deps.langgraph_runtime(FastAPI(), _sqlite_config()):
        pass

    assert len(_FakeRunManager.instances) == 1
    manager = _FakeRunManager.instances[0]
    assert manager.reconcile_calls
    assert manager.reconcile_calls[0]["error"]
    # With no newer run configured, the thread's latest run is the orphaned run itself,
    # so the thread is expected to be marked as errored.
    assert manager.list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
    assert thread_store.status_updates == [("thread-1", "error", None)]


@pytest.mark.anyio
async def test_sqlite_runtime_does_not_mark_thread_error_when_newer_run_is_success(monkeypatch):
    """Startup recovery should not let an old orphaned run overwrite a newer terminal thread state."""
    thread_store = _FakeThreadStore()
    _install_runtime_fakes(
        monkeypatch,
        thread_store,
        recovered_runs=[SimpleNamespace(run_id="old-running", thread_id="thread-1")],
        latest_by_thread={"thread-1": [SimpleNamespace(run_id="newer-success", thread_id="thread-1", status="success")]},
    )

    async with gateway_deps.langgraph_runtime(FastAPI(), _sqlite_config()):
        pass

    assert len(_FakeRunManager.instances) == 1
    manager = _FakeRunManager.instances[0]
    assert manager.reconcile_calls
    assert manager.list_by_thread_calls == [{"thread_id": "thread-1", "user_id": None, "limit": 1}]
    # The newer successful run wins, so no status update may be written for the thread.
    assert thread_store.status_updates == []
|