mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 09:25:57 +00:00
fix(runtime): protect sync singleton init and reset (#3413)
* fix(runtime): protect sync singleton init/reset with threading.Lock
* fix(runtime): serialize sync singleton init and reset
* make format
* test(runtime): assert store reset creates new singleton
* Apply suggestions from code review

  Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
* fix(runtime): load config outside singleton locks
* fix(runtime): share checkpointer config loading helper

---------

Co-authored-by: GODDiao <diaoshengjia@gmail.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -2,7 +2,9 @@
|
||||
|
||||
import sys
|
||||
import tomllib
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from threading import Barrier, Event, Lock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -10,12 +12,14 @@ import pytest
|
||||
import deerflow.config.app_config as app_config_module
|
||||
from deerflow.config.checkpointer_config import (
|
||||
CheckpointerConfig,
|
||||
ensure_config_loaded,
|
||||
get_checkpointer_config,
|
||||
load_checkpointer_config_from_dict,
|
||||
set_checkpointer_config,
|
||||
)
|
||||
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
|
||||
from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL
|
||||
from deerflow.runtime.store import get_store, reset_store
|
||||
from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL
|
||||
|
||||
|
||||
@@ -25,10 +29,90 @@ def reset_state():
|
||||
app_config_module._app_config = None
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
reset_store()
|
||||
yield
|
||||
app_config_module._app_config = None
|
||||
set_checkpointer_config(None)
|
||||
reset_checkpointer()
|
||||
reset_store()
|
||||
|
||||
|
||||
class _BlockingSingletonContext:
    """Context manager whose ``__enter__`` blocks until the test releases it.

    Entering bumps ``stats["enters"]``, sets ``entered`` and then waits on
    ``release``; exiting bumps ``stats["exits"]``. Both counters are updated
    while holding ``stats["lock"]``.

    Args:
        value: Object returned from ``__enter__``.
        entered: Event set as soon as ``__enter__`` has been reached.
        release: Event that must be set before ``__enter__`` may return.
        stats: Mapping with integer ``"enters"`` / ``"exits"`` counters and a
            ``"lock"`` entry usable as a context manager.
        timeout: Seconds to wait for ``release`` before failing.
    """

    def __init__(
        self,
        value: object,
        entered: Event,
        release: Event,
        stats: dict[str, object],
        timeout: float = 3,
    ):
        self._value = value
        self._entered = entered
        self._release = release
        self._stats = stats
        self._timeout = timeout

    def __enter__(self):
        """Record the entry, signal it, then block until released.

        Raises:
            AssertionError: If ``release`` is not set within ``timeout``.
        """
        with self._stats["lock"]:
            self._stats["enters"] += 1
        self._entered.set()
        # The wait is deliberately not written inside an ``assert`` statement:
        # it is the synchronisation point of the helper and must still run
        # when assert statements are compiled out.
        released = self._release.wait(timeout=self._timeout)
        if not released:
            raise AssertionError("timed out waiting to release singleton initialization")
        return self._value

    def __exit__(self, exc_type, exc, tb):
        """Record the exit; returns ``False`` so exceptions propagate."""
        with self._stats["lock"]:
            self._stats["exits"] += 1
        return False
|
||||
|
||||
|
||||
class _BlockingSingletonFactory:
    """Produces blocking singleton contexts that share one set of test hooks.

    ``value`` is the object every produced context yields, ``entered`` and
    ``release`` are the events handed to each context, and ``stats`` holds the
    ``"enters"`` / ``"exits"`` counters plus the ``"lock"`` protecting them.
    """

    def __init__(self):
        self.value = object()
        self.entered, self.release = Event(), Event()
        self.stats = dict(enters=0, exits=0, lock=Lock())

    def context_manager(self, _config):
        """Build a new blocking context; the config argument is ignored."""
        return _BlockingSingletonContext(
            self.value,
            self.entered,
            self.release,
            self.stats,
        )

    def _read_counter(self, key: str) -> int:
        # Read under the shared lock so a concurrent update is never observed
        # half-way through.
        with self.stats["lock"]:
            return self.stats[key]

    def enter_count(self) -> int:
        """Return how many times a produced context has been entered."""
        return self._read_counter("enters")

    def exit_count(self) -> int:
        """Return how many times a produced context has been exited."""
        return self._read_counter("exits")
|
||||
|
||||
|
||||
class _TrackingLock:
    """Lock wrapper that remembers whether it has ever been acquired.

    It wraps a ``threading.Lock`` and exposes ``acquire``, ``release``,
    ``locked`` and the context-manager protocol. The public ``acquired`` event
    is set on the first successful acquisition and is never cleared here.
    """

    def __init__(self):
        self.acquired = Event()
        self._inner_lock = Lock()

    def acquire(self, *args, **kwargs):
        """Acquire the wrapped lock, forwarding all arguments unchanged.

        Returns whatever the wrapped ``acquire`` returned.
        """
        outcome = self._inner_lock.acquire(*args, **kwargs)
        if not outcome:
            return outcome
        self.acquired.set()
        return outcome

    def release(self):
        """Release the wrapped lock."""
        self._inner_lock.release()

    def locked(self) -> bool:
        """Report whether the wrapped lock is currently held."""
        return self._inner_lock.locked()

    def __enter__(self):
        self.acquire()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.release()
        # Never suppress an exception raised inside the ``with`` body.
        return False
|
||||
|
||||
|
||||
def _call_getter_concurrently(getter, workers: int = 8) -> list[object]:
    """Invoke ``getter`` from ``workers`` threads that start together.

    A barrier sized for every worker plus the calling thread holds the workers
    back until all of them are ready. Results are returned in submission
    order. Each barrier wait and each result fetch uses a 3 second timeout.
    """
    start_gate = Barrier(workers + 1)

    def wait_then_call():
        start_gate.wait(timeout=3)
        return getter()

    with ThreadPoolExecutor(max_workers=workers) as pool:
        pending = [pool.submit(wait_then_call) for _ in range(workers)]
        # The calling thread is the last barrier party; once it arrives, all
        # workers proceed to the getter.
        start_gate.wait(timeout=3)
        outcomes = []
        for task in pending:
            outcomes.append(task.result(timeout=3))
    return outcomes
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -67,6 +151,26 @@ class TestCheckpointerConfig:
|
||||
set_checkpointer_config(None)
|
||||
assert get_checkpointer_config() is None
|
||||
|
||||
def test_ensure_config_loaded_loads_app_config_when_uninitialized(self):
    """``ensure_config_loaded`` calls ``get_app_config`` once when nothing is set.

    The patched ``get_app_config`` registers a ``memory`` checkpointer config
    as its side effect, which is what the test then expects to read back.
    """

    def register_memory_config():
        load_checkpointer_config_from_dict({"type": "memory"})

    app_config_patch = patch(
        "deerflow.config.app_config.get_app_config",
        side_effect=register_memory_config,
    )
    with app_config_patch as app_config_loader:
        ensure_config_loaded()

    app_config_loader.assert_called_once()
    loaded = get_checkpointer_config()
    assert loaded is not None
    assert loaded.type == "memory"
|
||||
|
||||
def test_ensure_config_loaded_skips_explicit_config(self):
    """``ensure_config_loaded`` must not call ``get_app_config`` once a config is set."""
    load_checkpointer_config_from_dict({"type": "memory"})

    app_config_patch = patch("deerflow.config.app_config.get_app_config")
    with app_config_patch as app_config_loader:
        ensure_config_loaded()

    app_config_loader.assert_not_called()
|
||||
|
||||
def test_invalid_type_raises(self):
|
||||
with pytest.raises(Exception):
|
||||
load_checkpointer_config_from_dict({"type": "unknown"})
|
||||
@@ -118,7 +222,7 @@ class TestGetCheckpointer:
|
||||
"""get_checkpointer should return InMemorySaver when not configured."""
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
|
||||
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError):
|
||||
with patch("deerflow.config.app_config.get_app_config", side_effect=FileNotFoundError):
|
||||
cp = get_checkpointer()
|
||||
assert cp is not None
|
||||
assert isinstance(cp, InMemorySaver)
|
||||
@@ -287,6 +391,143 @@ class TestGetCheckpointer:
|
||||
mock_saver_instance.setup.assert_called_once()
|
||||
|
||||
|
||||
class TestSyncSingletonThreadSafety:
    """Concurrency tests for the synchronous checkpointer and store singletons.

    The tests swap the providers' context-manager factories for
    ``_BlockingSingletonFactory`` so initialization can be paused in the
    middle, and use ``_TrackingLock`` to observe the providers' locks.
    """

    def test_store_reset_clears_singleton(self):
        """After ``reset_store`` the next ``get_store`` returns a different object."""
        load_checkpointer_config_from_dict({"type": "memory"})
        store1 = get_store()
        reset_store()
        store2 = get_store()
        assert store1 is not store2

    def test_concurrent_checkpointer_getter_creates_one_instance(self):
        """Concurrent ``get_checkpointer`` calls enter the factory exactly once."""
        load_checkpointer_config_from_dict({"type": "memory"})
        factory = _BlockingSingletonFactory()

        # The executor is a context manager so it is shut down (waiting for
        # its worker) on every exit path, before the patch is undone.
        with (
            patch("deerflow.runtime.checkpointer.provider._sync_checkpointer_cm", side_effect=factory.context_manager),
            ThreadPoolExecutor(max_workers=1) as executor,
        ):
            result_future = executor.submit(_call_getter_concurrently, get_checkpointer)
            assert factory.entered.wait(timeout=3)
            # ``release`` has not been set yet, so this wait is a 50 ms pause;
            # presumably it gives the other callers time to pile up behind
            # the in-progress initialization.
            factory.release.wait(timeout=0.05)
            factory.release.set()
            results = result_future.result(timeout=3)

        assert all(result is factory.value for result in results)
        assert factory.enter_count() == 1

    def test_concurrent_store_getter_creates_one_instance(self):
        """Concurrent ``get_store`` calls enter the factory exactly once."""
        load_checkpointer_config_from_dict({"type": "memory"})
        factory = _BlockingSingletonFactory()

        # The executor is a context manager so it is shut down (waiting for
        # its worker) on every exit path, before the patch is undone.
        with (
            patch("deerflow.runtime.store.provider._sync_store_cm", side_effect=factory.context_manager),
            ThreadPoolExecutor(max_workers=1) as executor,
        ):
            result_future = executor.submit(_call_getter_concurrently, get_store)
            assert factory.entered.wait(timeout=3)
            # ``release`` has not been set yet, so this wait is a 50 ms pause;
            # presumably it gives the other callers time to pile up behind
            # the in-progress initialization.
            factory.release.wait(timeout=0.05)
            factory.release.set()
            results = result_future.result(timeout=3)

        assert all(result is factory.value for result in results)
        assert factory.enter_count() == 1

    def test_checkpointer_loads_config_outside_singleton_lock(self):
        """Config loading runs while the checkpointer lock is not held."""
        tracking_lock = _TrackingLock()

        def fake_ensure_config_loaded():
            # Fails the test if the provider calls this with its lock held.
            assert not tracking_lock.locked()
            load_checkpointer_config_from_dict({"type": "memory"})

        with (
            patch("deerflow.runtime.checkpointer.provider._checkpointer_lock", tracking_lock),
            patch("deerflow.runtime.checkpointer.provider.ensure_config_loaded", side_effect=fake_ensure_config_loaded),
        ):
            checkpointer = get_checkpointer()

        assert checkpointer is not None
        # The lock must still have been taken at some point during the call.
        assert tracking_lock.acquired.is_set()

    def test_store_loads_config_outside_singleton_lock(self):
        """Config loading runs while the store lock is not held."""
        tracking_lock = _TrackingLock()

        def fake_ensure_config_loaded():
            # Fails the test if the provider calls this with its lock held.
            assert not tracking_lock.locked()
            load_checkpointer_config_from_dict({"type": "memory"})

        with (
            patch("deerflow.runtime.store.provider._store_lock", tracking_lock),
            patch("deerflow.runtime.store.provider.ensure_config_loaded", side_effect=fake_ensure_config_loaded),
        ):
            store = get_store()

        assert store is not None
        # The lock must still have been taken at some point during the call.
        assert tracking_lock.acquired.is_set()

    def test_checkpointer_reset_waits_for_initialization(self):
        """``reset_checkpointer`` does not finish while initialization is in flight."""
        load_checkpointer_config_from_dict({"type": "memory"})
        factory = _BlockingSingletonFactory()

        with (
            patch("deerflow.runtime.checkpointer.provider._sync_checkpointer_cm", side_effect=factory.context_manager),
            ThreadPoolExecutor(max_workers=2) as executor,
        ):
            get_future = executor.submit(get_checkpointer)
            assert factory.entered.wait(timeout=3)

            reset_started = Event()

            def reset_worker():
                reset_started.set()
                reset_checkpointer()

            reset_future = executor.submit(reset_worker)
            assert reset_started.wait(timeout=3)
            # ``release`` is still unset, so this is a 50 ms pause that lets
            # the reset worker reach the point where it should be blocked.
            factory.release.wait(timeout=0.05)

            assert not reset_future.done()
            assert factory.exit_count() == 0

            factory.release.set()
            assert get_future.result(timeout=3) is factory.value
            reset_future.result(timeout=3)

            assert factory.exit_count() == 1

    def test_store_reset_waits_for_initialization(self):
        """``reset_store`` does not finish while initialization is in flight."""
        load_checkpointer_config_from_dict({"type": "memory"})
        factory = _BlockingSingletonFactory()

        with (
            patch("deerflow.runtime.store.provider._sync_store_cm", side_effect=factory.context_manager),
            ThreadPoolExecutor(max_workers=2) as executor,
        ):
            get_future = executor.submit(get_store)
            assert factory.entered.wait(timeout=3)

            reset_started = Event()

            def reset_worker():
                reset_started.set()
                reset_store()

            reset_future = executor.submit(reset_worker)
            assert reset_started.wait(timeout=3)
            # ``release`` is still unset, so this is a 50 ms pause that lets
            # the reset worker reach the point where it should be blocked.
            factory.release.wait(timeout=0.05)

            assert not reset_future.done()
            assert factory.exit_count() == 0

            factory.release.set()
            assert get_future.result(timeout=3) is factory.value
            reset_future.result(timeout=3)

            assert factory.exit_count() == 1
|
||||
|
||||
|
||||
class TestAsyncCheckpointer:
|
||||
@pytest.mark.anyio
|
||||
async def test_sqlite_creates_parent_dir_via_to_thread(self):
|
||||
|
||||
Reference in New Issue
Block a user