style(storage): format storage package
This commit is contained in:
@@ -1,5 +1,5 @@
|
||||
from .enums import DataBaseType
|
||||
|
||||
__all__ = [
|
||||
'DataBaseType',
|
||||
"DataBaseType",
|
||||
]
|
||||
|
||||
@@ -3,7 +3,7 @@ from enum import IntEnum as SourceIntEnum
|
||||
from enum import StrEnum as SourceStrEnum
|
||||
from typing import Any, TypeVar
|
||||
|
||||
T = TypeVar('T', bound=Enum)
|
||||
T = TypeVar("T", bound=Enum)
|
||||
|
||||
|
||||
class _EnumBase:
|
||||
@@ -36,6 +36,6 @@ class StrEnum(_EnumBase, SourceStrEnum):
|
||||
class DataBaseType(StrEnum):
|
||||
"""Database type."""
|
||||
|
||||
sqlite = 'sqlite'
|
||||
mysql = 'mysql'
|
||||
postgresql = 'postgresql'
|
||||
sqlite = "sqlite"
|
||||
mysql = "mysql"
|
||||
postgresql = "postgresql"
|
||||
|
||||
@@ -92,8 +92,7 @@ class AppConfig(BaseModel):
|
||||
elif os.getenv("DEER_FLOW_CONFIG_PATH"):
|
||||
path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
|
||||
if not Path.exists(path):
|
||||
raise FileNotFoundError(
|
||||
f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
|
||||
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
|
||||
return path
|
||||
else:
|
||||
for path in _default_config_candidates():
|
||||
@@ -159,8 +158,7 @@ class AppConfig(BaseModel):
|
||||
|
||||
if user_version < example_version:
|
||||
logger.warning(
|
||||
"Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to "
|
||||
"merge new fields into your config.",
|
||||
"Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to merge new fields into your config.",
|
||||
user_version,
|
||||
example_version,
|
||||
)
|
||||
@@ -182,14 +180,12 @@ class AppConfig(BaseModel):
|
||||
return config
|
||||
|
||||
|
||||
|
||||
_app_config: AppConfig | None = None
|
||||
_app_config_path: Path | None = None
|
||||
_app_config_mtime: float | None = None
|
||||
_app_config_is_custom = False
|
||||
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
|
||||
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack",
|
||||
default=())
|
||||
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
|
||||
|
||||
|
||||
def _get_config_mtime(config_path: Path) -> float | None:
|
||||
|
||||
@@ -22,17 +22,14 @@ def _strip_legacy_state_prefix(path: str) -> str:
|
||||
if path == ".deer-flow":
|
||||
return "."
|
||||
if path.startswith(prefix):
|
||||
return path[len(prefix):]
|
||||
return path[len(prefix) :]
|
||||
return path
|
||||
|
||||
|
||||
class StorageConfig(BaseModel):
|
||||
driver: Literal["mysql", "sqlite", "postgres", "postgresql"] = Field(
|
||||
default="sqlite",
|
||||
description="Storage driver for both checkpointer and application data. "
|
||||
"'sqlite' for single-node deployment (default),"
|
||||
"'postgres' for production multi-node deployment, "
|
||||
"'mysql' for MySQL databases.",
|
||||
description="Storage driver for both checkpointer and application data. 'sqlite' for single-node deployment (default),'postgres' for production multi-node deployment, 'mysql' for MySQL databases.",
|
||||
)
|
||||
sqlite_dir: str = Field(
|
||||
default=".deer-flow/data",
|
||||
|
||||
@@ -23,7 +23,7 @@ id_key = Annotated[
|
||||
autoincrement=True,
|
||||
sort_order=-999,
|
||||
comment="Primary key ID",
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -16,10 +16,7 @@ def _validate_mysql_driver(db_url: URL) -> str:
|
||||
driver = url.get_driver_name()
|
||||
|
||||
if driver not in {"aiomysql", "asyncmy"}:
|
||||
raise ValueError(
|
||||
"MySQL persistence requires async SQLAlchemy driver "
|
||||
f"(aiomysql/asyncmy), got: {driver!r}"
|
||||
)
|
||||
raise ValueError(f"MySQL persistence requires async SQLAlchemy driver (aiomysql/asyncmy), got: {driver!r}")
|
||||
return driver
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
@@ -139,10 +139,7 @@ def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: objec
|
||||
if isinstance(value, int):
|
||||
bp = _bind(compiler, value, BigInteger(), **kw)
|
||||
if dialect.int_guard:
|
||||
return (
|
||||
f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} "
|
||||
f"THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
|
||||
)
|
||||
return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
|
||||
return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})"
|
||||
if isinstance(value, float):
|
||||
bp = _bind(compiler, value, Float(), **kw)
|
||||
|
||||
@@ -15,6 +15,7 @@ class AppPersistence:
|
||||
"""
|
||||
Unified runtime persistence bundle.
|
||||
"""
|
||||
|
||||
checkpointer: Checkpointer
|
||||
engine: AsyncEngine
|
||||
session_factory: async_sessionmaker[AsyncSession]
|
||||
|
||||
@@ -67,9 +67,7 @@ class DbFeedbackRepository(FeedbackRepositoryProtocol):
|
||||
return _to_feedback(model)
|
||||
|
||||
async def get_feedback(self, feedback_id: str) -> Feedback | None:
|
||||
result = await self._session.execute(
|
||||
select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)
|
||||
)
|
||||
result = await self._session.execute(select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_feedback(model) if model else None
|
||||
|
||||
@@ -112,9 +110,7 @@ class DbFeedbackRepository(FeedbackRepositoryProtocol):
|
||||
existing = await self.get_feedback(feedback_id)
|
||||
if existing is None:
|
||||
return False
|
||||
await self._session.execute(
|
||||
delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)
|
||||
)
|
||||
await self._session.execute(delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
|
||||
return True
|
||||
|
||||
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
|
||||
|
||||
@@ -64,9 +64,7 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
return _to_run(model)
|
||||
|
||||
async def get_run(self, run_id: str) -> Run | None:
|
||||
result = await self._session.execute(
|
||||
select(RunModel).where(RunModel.run_id == run_id)
|
||||
)
|
||||
result = await self._session.execute(select(RunModel).where(RunModel.run_id == run_id))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_run(model) if model else None
|
||||
|
||||
@@ -85,15 +83,11 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_run(m) for m in result.scalars().all()]
|
||||
|
||||
async def update_run_status(
|
||||
self, run_id: str, status: str, *, error: str | None = None
|
||||
) -> None:
|
||||
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
|
||||
values: dict = {"status": status}
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
await self._session.execute(
|
||||
update(RunModel).where(RunModel.run_id == run_id).values(**values)
|
||||
)
|
||||
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
|
||||
|
||||
async def delete_run(self, run_id: str) -> None:
|
||||
await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))
|
||||
@@ -106,11 +100,7 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
else:
|
||||
before_dt = datetime.fromisoformat(before)
|
||||
|
||||
result = await self._session.execute(
|
||||
select(RunModel)
|
||||
.where(RunModel.status == "pending", RunModel.created_time <= before_dt)
|
||||
.order_by(RunModel.created_time.asc())
|
||||
)
|
||||
result = await self._session.execute(select(RunModel).where(RunModel.status == "pending", RunModel.created_time <= before_dt).order_by(RunModel.created_time.asc()))
|
||||
return [_to_run(m) for m in result.scalars().all()]
|
||||
|
||||
async def update_run_completion(
|
||||
@@ -147,9 +137,7 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
values["last_ai_message"] = last_ai_message[:2000]
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
await self._session.execute(
|
||||
update(RunModel).where(RunModel.run_id == run_id).values(**values)
|
||||
)
|
||||
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
completed = RunModel.status.in_(("success", "error"))
|
||||
|
||||
@@ -158,13 +158,10 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
after_seq: int | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]:
|
||||
stmt = (
|
||||
select(RunEventModel)
|
||||
.where(
|
||||
RunEventModel.thread_id == thread_id,
|
||||
RunEventModel.run_id == run_id,
|
||||
RunEventModel.category == "message",
|
||||
)
|
||||
stmt = select(RunEventModel).where(
|
||||
RunEventModel.thread_id == thread_id,
|
||||
RunEventModel.run_id == run_id,
|
||||
RunEventModel.category == "message",
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(RunEventModel.user_id == user_id)
|
||||
@@ -182,11 +179,7 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
|
||||
|
||||
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(RunEventModel)
|
||||
.where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
|
||||
)
|
||||
stmt = select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(RunEventModel.user_id == user_id)
|
||||
count = await self._session.scalar(stmt)
|
||||
|
||||
@@ -55,12 +55,12 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
|
||||
return _to_thread_meta(model) if model else None
|
||||
|
||||
async def update_thread_meta(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
display_name: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
display_name: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
values: dict = {}
|
||||
if display_name is not None:
|
||||
@@ -71,21 +71,20 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
|
||||
values["meta"] = dict(metadata)
|
||||
if not values:
|
||||
return
|
||||
await self._session.execute(
|
||||
update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))
|
||||
await self._session.execute(update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))
|
||||
|
||||
async def delete_thread(self, thread_id: str) -> None:
|
||||
await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
|
||||
|
||||
async def search_threads(
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None = None,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None = None,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ThreadMeta]:
|
||||
stmt = select(ThreadMetaModel)
|
||||
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .timezone import get_timezone
|
||||
|
||||
__all__ = ["get_timezone"]
|
||||
__all__ = ["get_timezone"]
|
||||
|
||||
@@ -24,12 +24,8 @@ def test_storage_json_match_compiles_sqlite() -> None:
|
||||
table = _table()
|
||||
dialect = create_engine("sqlite://").dialect
|
||||
|
||||
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"json_type(t.data, '$.\"k\"') = 'null'"
|
||||
)
|
||||
assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"json_type(t.data, '$.\"k\"') = 'true'"
|
||||
)
|
||||
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'null'")
|
||||
assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'true'")
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "= 'integer'" in int_sql
|
||||
@@ -44,12 +40,8 @@ def test_storage_json_match_compiles_postgres() -> None:
|
||||
table = _table()
|
||||
dialect = postgresql.dialect()
|
||||
|
||||
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"json_typeof(t.data -> 'k') = 'null'"
|
||||
)
|
||||
assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"
|
||||
)
|
||||
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_typeof(t.data -> 'k') = 'null'")
|
||||
assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')")
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "CASE WHEN" in int_sql
|
||||
|
||||
@@ -110,9 +110,7 @@ def test_storage_models_import_without_config_file(tmp_path):
|
||||
[
|
||||
sys.executable,
|
||||
"-c",
|
||||
"from store.persistence.base_model import UniversalText, id_key; "
|
||||
"from store.repositories.models import RunEvent; "
|
||||
"print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
|
||||
"from store.persistence.base_model import UniversalText, id_key; from store.repositories.models import RunEvent; print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
|
||||
],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
|
||||
Reference in New Issue
Block a user