fix(runtime): protect sync singleton init and reset (#3413)

* fix(runtime): protect sync singleton init/reset with threading.Lock

* fix(runtime): serialize sync singleton init and reset

* make format

* test(runtime): assert store reset creates new singleton

* Apply suggestions from code review

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* fix(runtime): load config outside singleton locks

* fix(runtime): share checkpointer config loading helper

---------

Co-authored-by: GODDiao <diaoshengjia@gmail.com>
Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Nan Gao
2026-06-08 02:38:36 +02:00
committed by GitHub
parent 3b4c9ff733
commit f725a963d5
4 changed files with 313 additions and 56 deletions
@@ -41,6 +41,20 @@ def set_checkpointer_config(config: CheckpointerConfig | None) -> None:
_checkpointer_config = config _checkpointer_config = config
def ensure_config_loaded() -> None:
"""Lazily load app config when checkpointer config has not been initialized."""
from deerflow.config.app_config import _app_config, get_app_config
config = get_checkpointer_config()
if config is not None or _app_config is not None:
return
try:
get_app_config()
except FileNotFoundError:
pass
def load_checkpointer_config_from_dict(config_dict: dict | None) -> None: def load_checkpointer_config_from_dict(config_dict: dict | None) -> None:
"""Load checkpointer configuration from a dictionary.""" """Load checkpointer configuration from a dictionary."""
global _checkpointer_config global _checkpointer_config
@@ -21,12 +21,13 @@ from __future__ import annotations
import contextlib import contextlib
import logging import logging
import threading
from collections.abc import Iterator from collections.abc import Iterator
from langgraph.types import Checkpointer from langgraph.types import Checkpointer
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import CheckpointerConfig from deerflow.config.checkpointer_config import CheckpointerConfig, ensure_config_loaded
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -100,6 +101,7 @@ def _sync_checkpointer_cm(config: CheckpointerConfig) -> Iterator[Checkpointer]:
_checkpointer: Checkpointer | None = None _checkpointer: Checkpointer | None = None
_checkpointer_ctx = None # open context manager keeping the connection alive _checkpointer_ctx = None # open context manager keeping the connection alive
_checkpointer_lock = threading.Lock()
def get_checkpointer() -> Checkpointer: def get_checkpointer() -> Checkpointer:
@@ -116,34 +118,29 @@ def get_checkpointer() -> Checkpointer:
if _checkpointer is not None: if _checkpointer is not None:
return _checkpointer return _checkpointer
# Ensure app config is loaded before checking checkpointer config # Config loading can reset both persistence singletons. Keep it outside
# This prevents returning InMemorySaver when config.yaml actually has a checkpointer section # this provider lock to avoid cross-provider lock-order inversion.
# but hasn't been loaded yet ensure_config_loaded()
from deerflow.config.app_config import _app_config
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config() with _checkpointer_lock:
if _checkpointer is not None:
return _checkpointer
from deerflow.config.checkpointer_config import get_checkpointer_config
if config is None and _app_config is None:
# Only load app config lazily when neither the app config nor an explicit
# checkpointer config has been initialized yet. This keeps tests that
# intentionally set the global checkpointer config isolated from any
# ambient config.yaml on disk.
try:
get_app_config()
except FileNotFoundError:
# In test environments without config.yaml, this is expected.
pass
config = get_checkpointer_config() config = get_checkpointer_config()
if config is None:
from langgraph.checkpoint.memory import InMemorySaver
logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)") if config is None:
_checkpointer = InMemorySaver() from langgraph.checkpoint.memory import InMemorySaver
return _checkpointer
_checkpointer_ctx = _sync_checkpointer_cm(config) logger.info("Checkpointer: using InMemorySaver (in-process, not persistent)")
_checkpointer = _checkpointer_ctx.__enter__() _checkpointer = InMemorySaver()
return _checkpointer
checkpointer_ctx = _sync_checkpointer_cm(config)
checkpointer = checkpointer_ctx.__enter__()
_checkpointer_ctx = checkpointer_ctx
_checkpointer = checkpointer
return _checkpointer return _checkpointer
@@ -155,13 +152,14 @@ def reset_checkpointer() -> None:
Useful in tests or after a configuration change. Useful in tests or after a configuration change.
""" """
global _checkpointer, _checkpointer_ctx global _checkpointer, _checkpointer_ctx
if _checkpointer_ctx is not None: with _checkpointer_lock:
try: if _checkpointer_ctx is not None:
_checkpointer_ctx.__exit__(None, None, None) try:
except Exception: _checkpointer_ctx.__exit__(None, None, None)
logger.warning("Error during checkpointer cleanup", exc_info=True) except Exception:
_checkpointer_ctx = None logger.warning("Error during checkpointer cleanup", exc_info=True)
_checkpointer = None _checkpointer_ctx = None
_checkpointer = None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -22,11 +22,13 @@ from __future__ import annotations
import contextlib import contextlib
import logging import logging
import threading
from collections.abc import Iterator from collections.abc import Iterator
from langgraph.store.base import BaseStore from langgraph.store.base import BaseStore
from deerflow.config.app_config import get_app_config from deerflow.config.app_config import get_app_config
from deerflow.config.checkpointer_config import ensure_config_loaded
from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str from deerflow.runtime.store._sqlite_utils import ensure_sqlite_parent_dir, resolve_sqlite_conn_str
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -100,6 +102,7 @@ def _sync_store_cm(config) -> Iterator[BaseStore]:
_store: BaseStore | None = None _store: BaseStore | None = None
_store_ctx = None # open context manager keeping the connection alive _store_ctx = None # open context manager keeping the connection alive
_store_lock = threading.Lock()
def get_store() -> BaseStore: def get_store() -> BaseStore:
@@ -117,29 +120,29 @@ def get_store() -> BaseStore:
if _store is not None: if _store is not None:
return _store return _store
# Lazily load app config, mirroring the checkpointer singleton pattern so # Config loading can reset both persistence singletons. Keep it outside
# that tests that set the global checkpointer config explicitly remain isolated. # this provider lock to avoid cross-provider lock-order inversion.
from deerflow.config.app_config import _app_config ensure_config_loaded()
from deerflow.config.checkpointer_config import get_checkpointer_config
config = get_checkpointer_config() with _store_lock:
if _store is not None:
return _store
from deerflow.config.checkpointer_config import get_checkpointer_config
if config is None and _app_config is None:
try:
get_app_config()
except FileNotFoundError:
pass
config = get_checkpointer_config() config = get_checkpointer_config()
if config is None: if config is None:
from langgraph.store.memory import InMemoryStore from langgraph.store.memory import InMemoryStore
logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.") logger.warning("No 'checkpointer' section in config.yaml — using InMemoryStore for the store. Thread list will be lost on server restart. Configure a sqlite or postgres backend for persistence.")
_store = InMemoryStore() _store = InMemoryStore()
return _store return _store
_store_ctx = _sync_store_cm(config) store_ctx = _sync_store_cm(config)
_store = _store_ctx.__enter__() store = store_ctx.__enter__()
_store_ctx = store_ctx
_store = store
return _store return _store
@@ -150,13 +153,14 @@ def reset_store() -> None:
Useful in tests or after a configuration change. Useful in tests or after a configuration change.
""" """
global _store, _store_ctx global _store, _store_ctx
if _store_ctx is not None: with _store_lock:
try: if _store_ctx is not None:
_store_ctx.__exit__(None, None, None) try:
except Exception: _store_ctx.__exit__(None, None, None)
logger.warning("Error during store cleanup", exc_info=True) except Exception:
_store_ctx = None logger.warning("Error during store cleanup", exc_info=True)
_store = None _store_ctx = None
_store = None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
+242 -1
View File
@@ -2,7 +2,9 @@
import sys import sys
import tomllib import tomllib
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path from pathlib import Path
from threading import Barrier, Event, Lock
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@@ -10,12 +12,14 @@ import pytest
import deerflow.config.app_config as app_config_module import deerflow.config.app_config as app_config_module
from deerflow.config.checkpointer_config import ( from deerflow.config.checkpointer_config import (
CheckpointerConfig, CheckpointerConfig,
ensure_config_loaded,
get_checkpointer_config, get_checkpointer_config,
load_checkpointer_config_from_dict, load_checkpointer_config_from_dict,
set_checkpointer_config, set_checkpointer_config,
) )
from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer from deerflow.runtime.checkpointer import get_checkpointer, reset_checkpointer
from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL from deerflow.runtime.checkpointer.provider import POSTGRES_INSTALL
from deerflow.runtime.store import get_store, reset_store
from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL from deerflow.runtime.store.provider import POSTGRES_STORE_INSTALL
@@ -25,10 +29,90 @@ def reset_state():
app_config_module._app_config = None app_config_module._app_config = None
set_checkpointer_config(None) set_checkpointer_config(None)
reset_checkpointer() reset_checkpointer()
reset_store()
yield yield
app_config_module._app_config = None app_config_module._app_config = None
set_checkpointer_config(None) set_checkpointer_config(None)
reset_checkpointer() reset_checkpointer()
reset_store()
class _BlockingSingletonContext:
def __init__(self, value: object, entered: Event, release: Event, stats: dict[str, object]):
self._value = value
self._entered = entered
self._release = release
self._stats = stats
def __enter__(self):
with self._stats["lock"]:
self._stats["enters"] += 1
self._entered.set()
assert self._release.wait(timeout=3), "timed out waiting to release singleton initialization"
return self._value
def __exit__(self, exc_type, exc, tb):
with self._stats["lock"]:
self._stats["exits"] += 1
return False
class _BlockingSingletonFactory:
def __init__(self):
self.value = object()
self.entered = Event()
self.release = Event()
self.stats = {"enters": 0, "exits": 0, "lock": Lock()}
def context_manager(self, _config):
return _BlockingSingletonContext(self.value, self.entered, self.release, self.stats)
def enter_count(self) -> int:
with self.stats["lock"]:
return self.stats["enters"]
def exit_count(self) -> int:
with self.stats["lock"]:
return self.stats["exits"]
class _TrackingLock:
def __init__(self):
self._lock = Lock()
self.acquired = Event()
def acquire(self, *args, **kwargs):
acquired = self._lock.acquire(*args, **kwargs)
if acquired:
self.acquired.set()
return acquired
def release(self):
self._lock.release()
def __enter__(self):
self.acquire()
return self
def __exit__(self, exc_type, exc, tb):
self.release()
return False
def locked(self) -> bool:
return self._lock.locked()
def _call_getter_concurrently(getter, workers: int = 8) -> list[object]:
ready = Barrier(workers + 1)
def worker():
ready.wait(timeout=3)
return getter()
with ThreadPoolExecutor(max_workers=workers) as executor:
futures = [executor.submit(worker) for _ in range(workers)]
ready.wait(timeout=3)
return [future.result(timeout=3) for future in futures]
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -67,6 +151,26 @@ class TestCheckpointerConfig:
set_checkpointer_config(None) set_checkpointer_config(None)
assert get_checkpointer_config() is None assert get_checkpointer_config() is None
def test_ensure_config_loaded_loads_app_config_when_uninitialized(self):
def fake_get_app_config():
load_checkpointer_config_from_dict({"type": "memory"})
with patch("deerflow.config.app_config.get_app_config", side_effect=fake_get_app_config) as mock_get_app_config:
ensure_config_loaded()
mock_get_app_config.assert_called_once()
config = get_checkpointer_config()
assert config is not None
assert config.type == "memory"
def test_ensure_config_loaded_skips_explicit_config(self):
load_checkpointer_config_from_dict({"type": "memory"})
with patch("deerflow.config.app_config.get_app_config") as mock_get_app_config:
ensure_config_loaded()
mock_get_app_config.assert_not_called()
def test_invalid_type_raises(self): def test_invalid_type_raises(self):
with pytest.raises(Exception): with pytest.raises(Exception):
load_checkpointer_config_from_dict({"type": "unknown"}) load_checkpointer_config_from_dict({"type": "unknown"})
@@ -118,7 +222,7 @@ class TestGetCheckpointer:
"""get_checkpointer should return InMemorySaver when not configured.""" """get_checkpointer should return InMemorySaver when not configured."""
from langgraph.checkpoint.memory import InMemorySaver from langgraph.checkpoint.memory import InMemorySaver
with patch("deerflow.runtime.checkpointer.provider.get_app_config", side_effect=FileNotFoundError): with patch("deerflow.config.app_config.get_app_config", side_effect=FileNotFoundError):
cp = get_checkpointer() cp = get_checkpointer()
assert cp is not None assert cp is not None
assert isinstance(cp, InMemorySaver) assert isinstance(cp, InMemorySaver)
@@ -287,6 +391,143 @@ class TestGetCheckpointer:
mock_saver_instance.setup.assert_called_once() mock_saver_instance.setup.assert_called_once()
class TestSyncSingletonThreadSafety:
def test_store_reset_clears_singleton(self):
load_checkpointer_config_from_dict({"type": "memory"})
store1 = get_store()
reset_store()
store2 = get_store()
assert store1 is not store2
def test_concurrent_checkpointer_getter_creates_one_instance(self):
load_checkpointer_config_from_dict({"type": "memory"})
factory = _BlockingSingletonFactory()
with patch("deerflow.runtime.checkpointer.provider._sync_checkpointer_cm", side_effect=factory.context_manager):
futures_started = ThreadPoolExecutor(max_workers=1)
try:
result_future = futures_started.submit(_call_getter_concurrently, get_checkpointer)
assert factory.entered.wait(timeout=3)
factory.release.wait(timeout=0.05)
factory.release.set()
results = result_future.result(timeout=3)
finally:
futures_started.shutdown(wait=True)
assert all(result is factory.value for result in results)
assert factory.enter_count() == 1
def test_concurrent_store_getter_creates_one_instance(self):
load_checkpointer_config_from_dict({"type": "memory"})
factory = _BlockingSingletonFactory()
with patch("deerflow.runtime.store.provider._sync_store_cm", side_effect=factory.context_manager):
futures_started = ThreadPoolExecutor(max_workers=1)
try:
result_future = futures_started.submit(_call_getter_concurrently, get_store)
assert factory.entered.wait(timeout=3)
factory.release.wait(timeout=0.05)
factory.release.set()
results = result_future.result(timeout=3)
finally:
futures_started.shutdown(wait=True)
assert all(result is factory.value for result in results)
assert factory.enter_count() == 1
def test_checkpointer_loads_config_outside_singleton_lock(self):
tracking_lock = _TrackingLock()
def fake_ensure_config_loaded():
assert not tracking_lock.locked()
load_checkpointer_config_from_dict({"type": "memory"})
with (
patch("deerflow.runtime.checkpointer.provider._checkpointer_lock", tracking_lock),
patch("deerflow.runtime.checkpointer.provider.ensure_config_loaded", side_effect=fake_ensure_config_loaded),
):
checkpointer = get_checkpointer()
assert checkpointer is not None
assert tracking_lock.acquired.is_set()
def test_store_loads_config_outside_singleton_lock(self):
tracking_lock = _TrackingLock()
def fake_ensure_config_loaded():
assert not tracking_lock.locked()
load_checkpointer_config_from_dict({"type": "memory"})
with (
patch("deerflow.runtime.store.provider._store_lock", tracking_lock),
patch("deerflow.runtime.store.provider.ensure_config_loaded", side_effect=fake_ensure_config_loaded),
):
store = get_store()
assert store is not None
assert tracking_lock.acquired.is_set()
def test_checkpointer_reset_waits_for_initialization(self):
load_checkpointer_config_from_dict({"type": "memory"})
factory = _BlockingSingletonFactory()
with (
patch("deerflow.runtime.checkpointer.provider._sync_checkpointer_cm", side_effect=factory.context_manager),
ThreadPoolExecutor(max_workers=2) as executor,
):
get_future = executor.submit(get_checkpointer)
assert factory.entered.wait(timeout=3)
reset_started = Event()
def reset_worker():
reset_started.set()
reset_checkpointer()
reset_future = executor.submit(reset_worker)
assert reset_started.wait(timeout=3)
factory.release.wait(timeout=0.05)
assert not reset_future.done()
assert factory.exit_count() == 0
factory.release.set()
assert get_future.result(timeout=3) is factory.value
reset_future.result(timeout=3)
assert factory.exit_count() == 1
def test_store_reset_waits_for_initialization(self):
load_checkpointer_config_from_dict({"type": "memory"})
factory = _BlockingSingletonFactory()
with (
patch("deerflow.runtime.store.provider._sync_store_cm", side_effect=factory.context_manager),
ThreadPoolExecutor(max_workers=2) as executor,
):
get_future = executor.submit(get_store)
assert factory.entered.wait(timeout=3)
reset_started = Event()
def reset_worker():
reset_started.set()
reset_store()
reset_future = executor.submit(reset_worker)
assert reset_started.wait(timeout=3)
factory.release.wait(timeout=0.05)
assert not reset_future.done()
assert factory.exit_count() == 0
factory.release.set()
assert get_future.result(timeout=3) is factory.value
reset_future.result(timeout=3)
assert factory.exit_count() == 1
class TestAsyncCheckpointer: class TestAsyncCheckpointer:
@pytest.mark.anyio @pytest.mark.anyio
async def test_sqlite_creates_parent_dir_via_to_thread(self): async def test_sqlite_creates_parent_dir_via_to_thread(self):