From 34ec205e1d32fe8a1638c5f099de4712d286a090 Mon Sep 17 00:00:00 2001 From: rayhpeng Date: Wed, 13 May 2026 12:52:34 +0800 Subject: [PATCH] style(storage): format storage package --- .../packages/storage/store/common/__init__.py | 2 +- .../packages/storage/store/common/enums.py | 8 ++--- .../storage/store/config/app_config.py | 10 ++---- .../storage/store/config/storage_config.py | 7 ++--- .../storage/store/persistence/base_model.py | 2 +- .../store/persistence/drivers/mysql.py | 5 +-- .../store/persistence/drivers/sqlite.py | 1 - .../storage/store/persistence/json_compat.py | 5 +-- .../storage/store/persistence/types.py | 1 + .../storage/store/repositories/db/feedback.py | 8 ++--- .../storage/store/repositories/db/run.py | 22 +++---------- .../store/repositories/db/run_event.py | 17 +++------- .../store/repositories/db/thread_meta.py | 31 +++++++++---------- .../packages/storage/store/utils/__init__.py | 2 +- backend/tests/test_storage_json_compat.py | 16 +++------- .../tests/test_storage_persistence_config.py | 4 +-- 16 files changed, 47 insertions(+), 94 deletions(-) diff --git a/backend/packages/storage/store/common/__init__.py b/backend/packages/storage/store/common/__init__.py index e21d63c5d..3b0a4888d 100644 --- a/backend/packages/storage/store/common/__init__.py +++ b/backend/packages/storage/store/common/__init__.py @@ -1,5 +1,5 @@ from .enums import DataBaseType __all__ = [ - 'DataBaseType', + "DataBaseType", ] diff --git a/backend/packages/storage/store/common/enums.py b/backend/packages/storage/store/common/enums.py index 1df841835..d746aa013 100644 --- a/backend/packages/storage/store/common/enums.py +++ b/backend/packages/storage/store/common/enums.py @@ -3,7 +3,7 @@ from enum import IntEnum as SourceIntEnum from enum import StrEnum as SourceStrEnum from typing import Any, TypeVar -T = TypeVar('T', bound=Enum) +T = TypeVar("T", bound=Enum) class _EnumBase: @@ -36,6 +36,6 @@ class StrEnum(_EnumBase, SourceStrEnum): class DataBaseType(StrEnum): """Database type.""" - sqlite = 'sqlite' - mysql = 'mysql' - postgresql = 'postgresql' + sqlite = "sqlite" + mysql = "mysql" + postgresql = "postgresql" diff --git a/backend/packages/storage/store/config/app_config.py b/backend/packages/storage/store/config/app_config.py index fc61d7b02..94c63e640 100644 --- a/backend/packages/storage/store/config/app_config.py +++ b/backend/packages/storage/store/config/app_config.py @@ -92,8 +92,7 @@ class AppConfig(BaseModel): elif os.getenv("DEER_FLOW_CONFIG_PATH"): path = Path(os.getenv("DEER_FLOW_CONFIG_PATH")) if not Path.exists(path): - raise FileNotFoundError( - f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}") + raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}") return path else: for path in _default_config_candidates(): @@ -159,8 +158,7 @@ class AppConfig(BaseModel): if user_version < example_version: logger.warning( - "Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to " - "merge new fields into your config.", + "Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to merge new fields into your config.", user_version, example_version, ) @@ -182,14 +180,12 @@ class AppConfig(BaseModel): return config - _app_config: AppConfig | None = None _app_config_path: Path | None = None _app_config_mtime: float | None = None _app_config_is_custom = False _current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None) -_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", - default=()) +_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=()) def _get_config_mtime(config_path: Path) -> float | None: diff --git a/backend/packages/storage/store/config/storage_config.py b/backend/packages/storage/store/config/storage_config.py index 9f55a19aa..98c86a275 100644 --- a/backend/packages/storage/store/config/storage_config.py +++ b/backend/packages/storage/store/config/storage_config.py @@ -22,17 +22,14 @@ def _strip_legacy_state_prefix(path: str) -> str: if path == ".deer-flow": return "." if path.startswith(prefix): - return path[len(prefix):] + return path[len(prefix) :] return path class StorageConfig(BaseModel): driver: Literal["mysql", "sqlite", "postgres", "postgresql"] = Field( default="sqlite", - description="Storage driver for both checkpointer and application data. " - "'sqlite' for single-node deployment (default)," - "'postgres' for production multi-node deployment, " - "'mysql' for MySQL databases.", + description="Storage driver for both checkpointer and application data. 'sqlite' for single-node deployment (default),'postgres' for production multi-node deployment, 'mysql' for MySQL databases.", ) sqlite_dir: str = Field( default=".deer-flow/data", diff --git a/backend/packages/storage/store/persistence/base_model.py b/backend/packages/storage/store/persistence/base_model.py index ed9006718..beeef56eb 100644 --- a/backend/packages/storage/store/persistence/base_model.py +++ b/backend/packages/storage/store/persistence/base_model.py @@ -23,7 +23,7 @@ id_key = Annotated[ autoincrement=True, sort_order=-999, comment="Primary key ID", - ) + ), ] diff --git a/backend/packages/storage/store/persistence/drivers/mysql.py b/backend/packages/storage/store/persistence/drivers/mysql.py index c63d10155..72e4efb5e 100644 --- a/backend/packages/storage/store/persistence/drivers/mysql.py +++ b/backend/packages/storage/store/persistence/drivers/mysql.py @@ -16,10 +16,7 @@ def _validate_mysql_driver(db_url: URL) -> str: driver = url.get_driver_name() if driver not in {"aiomysql", "asyncmy"}: - raise ValueError( - "MySQL persistence requires async SQLAlchemy driver " - f"(aiomysql/asyncmy), got: {driver!r}" - ) + raise ValueError(f"MySQL persistence requires async SQLAlchemy driver (aiomysql/asyncmy), got: {driver!r}") return driver diff --git a/backend/packages/storage/store/persistence/drivers/sqlite.py b/backend/packages/storage/store/persistence/drivers/sqlite.py index be669baa5..52b1cea1e 100644 --- a/backend/packages/storage/store/persistence/drivers/sqlite.py +++ b/backend/packages/storage/store/persistence/drivers/sqlite.py @@ -1,4 +1,3 @@ - from __future__ import annotations import json diff --git a/backend/packages/storage/store/persistence/json_compat.py b/backend/packages/storage/store/persistence/json_compat.py index acd85fd34..0d51915c1 100644 --- a/backend/packages/storage/store/persistence/json_compat.py +++ b/backend/packages/storage/store/persistence/json_compat.py @@ -139,10 +139,7 @@ def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: objec if isinstance(value, int): bp = _bind(compiler, value, BigInteger(), **kw) if dialect.int_guard: - return ( - f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} " - f"THEN CAST({extract} AS {dialect.int_cast}) END = {bp})" - ) + return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})" return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})" if isinstance(value, float): bp = _bind(compiler, value, Float(), **kw) diff --git a/backend/packages/storage/store/persistence/types.py b/backend/packages/storage/store/persistence/types.py index 715cdb7a2..4e1eb8cd2 100644 --- a/backend/packages/storage/store/persistence/types.py +++ b/backend/packages/storage/store/persistence/types.py @@ -15,6 +15,7 @@ class AppPersistence: """ Unified runtime persistence bundle. """ + checkpointer: Checkpointer engine: AsyncEngine session_factory: async_sessionmaker[AsyncSession] diff --git a/backend/packages/storage/store/repositories/db/feedback.py b/backend/packages/storage/store/repositories/db/feedback.py index 67a5c05fa..f5c18a0d8 100644 --- a/backend/packages/storage/store/repositories/db/feedback.py +++ b/backend/packages/storage/store/repositories/db/feedback.py @@ -67,9 +67,7 @@ class DbFeedbackRepository(FeedbackRepositoryProtocol): return _to_feedback(model) async def get_feedback(self, feedback_id: str) -> Feedback | None: - result = await self._session.execute( - select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id) - ) + result = await self._session.execute(select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)) model = result.scalar_one_or_none() return _to_feedback(model) if model else None @@ -112,9 +110,7 @@ class DbFeedbackRepository(FeedbackRepositoryProtocol): existing = await self.get_feedback(feedback_id) if existing is None: return False - await self._session.execute( - delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id) - ) + await self._session.execute(delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)) return True async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool: diff --git a/backend/packages/storage/store/repositories/db/run.py b/backend/packages/storage/store/repositories/db/run.py index 146ec1e46..e6c07c45a 100644 --- a/backend/packages/storage/store/repositories/db/run.py +++ b/backend/packages/storage/store/repositories/db/run.py @@ -64,9 +64,7 @@ class DbRunRepository(RunRepositoryProtocol): return _to_run(model) async def get_run(self, run_id: str) -> Run | None: - result = await self._session.execute( - select(RunModel).where(RunModel.run_id == run_id) - ) + result = await self._session.execute(select(RunModel).where(RunModel.run_id == run_id)) model = result.scalar_one_or_none() return _to_run(model) if model else None @@ -85,15 +83,11 @@ class DbRunRepository(RunRepositoryProtocol): result = await self._session.execute(stmt) return [_to_run(m) for m in result.scalars().all()] - async def update_run_status( - self, run_id: str, status: str, *, error: str | None = None - ) -> None: + async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None: values: dict = {"status": status} if error is not None: values["error"] = error - await self._session.execute( - update(RunModel).where(RunModel.run_id == run_id).values(**values) - ) + await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values)) async def delete_run(self, run_id: str) -> None: await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id)) @@ -106,11 +100,7 @@ class DbRunRepository(RunRepositoryProtocol): else: before_dt = datetime.fromisoformat(before) - result = await self._session.execute( - select(RunModel) - .where(RunModel.status == "pending", RunModel.created_time <= before_dt) - .order_by(RunModel.created_time.asc()) - ) + result = await self._session.execute(select(RunModel).where(RunModel.status == "pending", RunModel.created_time <= before_dt).order_by(RunModel.created_time.asc())) return [_to_run(m) for m in result.scalars().all()] async def update_run_completion( @@ -147,9 +137,7 @@ class DbRunRepository(RunRepositoryProtocol): values["last_ai_message"] = last_ai_message[:2000] if error is not None: values["error"] = error - await self._session.execute( - update(RunModel).where(RunModel.run_id == run_id).values(**values) - ) + await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values)) async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: completed = RunModel.status.in_(("success", "error")) diff --git a/backend/packages/storage/store/repositories/db/run_event.py b/backend/packages/storage/store/repositories/db/run_event.py index a8c312e8f..df5c005e6 100644 --- a/backend/packages/storage/store/repositories/db/run_event.py +++ b/backend/packages/storage/store/repositories/db/run_event.py @@ -158,13 +158,10 @@ class DbRunEventRepository(RunEventRepositoryProtocol): after_seq: int | None = None, user_id: str | None = None, ) -> list[RunEvent]: - stmt = ( - select(RunEventModel) - .where( - RunEventModel.thread_id == thread_id, - RunEventModel.run_id == run_id, - RunEventModel.category == "message", - ) + stmt = select(RunEventModel).where( + RunEventModel.thread_id == thread_id, + RunEventModel.run_id == run_id, + RunEventModel.category == "message", ) if user_id is not None: stmt = stmt.where(RunEventModel.user_id == user_id) @@ -182,11 +179,7 @@ class DbRunEventRepository(RunEventRepositoryProtocol): return list(reversed([_to_run_event(row) for row in result.scalars().all()])) async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int: - stmt = ( - select(func.count()) - .select_from(RunEventModel) - .where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message") - ) + stmt = select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message") if user_id is not None: stmt = stmt.where(RunEventModel.user_id == user_id) count = await self._session.scalar(stmt) diff --git a/backend/packages/storage/store/repositories/db/thread_meta.py b/backend/packages/storage/store/repositories/db/thread_meta.py index af06b2da9..6b13b181c 100644 --- a/backend/packages/storage/store/repositories/db/thread_meta.py +++ b/backend/packages/storage/store/repositories/db/thread_meta.py @@ -55,12 +55,12 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol): return _to_thread_meta(model) if model else None async def update_thread_meta( - self, - thread_id: str, - *, - display_name: str | None = None, - status: str | None = None, - metadata: dict[str, Any] | None = None, + self, + thread_id: str, + *, + display_name: str | None = None, + status: str | None = None, + metadata: dict[str, Any] | None = None, ) -> None: values: dict = {} if display_name is not None: @@ -71,21 +71,20 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol): values["meta"] = dict(metadata) if not values: return - await self._session.execute( - update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values)) + await self._session.execute(update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values)) async def delete_thread(self, thread_id: str) -> None: await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id)) async def search_threads( - self, - *, - metadata: dict[str, Any] | None = None, - status: str | None = None, - user_id: str | None = None, - assistant_id: str | None = None, - limit: int = 100, - offset: int = 0, + self, + *, + metadata: dict[str, Any] | None = None, + status: str | None = None, + user_id: str | None = None, + assistant_id: str | None = None, + limit: int = 100, + offset: int = 0, ) -> list[ThreadMeta]: stmt = select(ThreadMetaModel) diff --git a/backend/packages/storage/store/utils/__init__.py b/backend/packages/storage/store/utils/__init__.py index 1ee6c9df9..e693db633 100644 --- a/backend/packages/storage/store/utils/__init__.py +++ b/backend/packages/storage/store/utils/__init__.py @@ -1,3 +1,3 @@ from .timezone import get_timezone -__all__ = ["get_timezone"] \ No newline at end of file +__all__ = ["get_timezone"] diff --git a/backend/tests/test_storage_json_compat.py b/backend/tests/test_storage_json_compat.py index 1ca1645ee..985b77fdf 100644 --- a/backend/tests/test_storage_json_compat.py +++ b/backend/tests/test_storage_json_compat.py @@ -24,12 +24,8 @@ def test_storage_json_match_compiles_sqlite() -> None: table = _table() dialect = create_engine("sqlite://").dialect - assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ( - "json_type(t.data, '$.\"k\"') = 'null'" - ) - assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ( - "json_type(t.data, '$.\"k\"') = 'true'" - ) + assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'null'") + assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'true'") int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) assert "= 'integer'" in int_sql @@ -44,12 +40,8 @@ def test_storage_json_match_compiles_postgres() -> None: table = _table() dialect = postgresql.dialect() - assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ( - "json_typeof(t.data -> 'k') = 'null'" - ) - assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ( - "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')" - ) + assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_typeof(t.data -> 'k') = 'null'") + assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')") int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) assert "CASE WHEN" in int_sql diff --git a/backend/tests/test_storage_persistence_config.py b/backend/tests/test_storage_persistence_config.py index b718d58fb..55dfcf668 100644 --- a/backend/tests/test_storage_persistence_config.py +++ b/backend/tests/test_storage_persistence_config.py @@ -110,9 +110,7 @@ def test_storage_models_import_without_config_file(tmp_path): [ sys.executable, "-c", - "from store.persistence.base_model import UniversalText, id_key; " - "from store.repositories.models import RunEvent; " - "print(UniversalText.__name__, RunEvent.__tablename__, id_key)", + "from store.persistence.base_model import UniversalText, id_key; from store.repositories.models import RunEvent; print(UniversalText.__name__, RunEvent.__tablename__, id_key)", ], check=False, capture_output=True,