mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 00:45:57 +00:00
fix(storage): harden sql persistence compatibility
This commit is contained in:
@@ -11,7 +11,7 @@ from store.persistence.shared import close_in_order
|
|||||||
from store.persistence.types import AppPersistence
|
from store.persistence.types import AppPersistence
|
||||||
|
|
||||||
|
|
||||||
def _validate_mysql_driver(db_url: str) -> str:
|
def _validate_mysql_driver(db_url: URL) -> str:
|
||||||
url = make_url(db_url)
|
url = make_url(db_url)
|
||||||
driver = url.get_driver_name()
|
driver = url.get_driver_name()
|
||||||
|
|
||||||
@@ -23,6 +23,10 @@ def _validate_mysql_driver(db_url: str) -> str:
|
|||||||
return driver
|
return driver
|
||||||
|
|
||||||
|
|
||||||
|
def _checkpoint_conn_string(db_url: URL) -> str:
    """Render *db_url* as a plain connection string for the checkpoint saver.

    The password is rendered in clear text (``hide_password=False``); with
    the default masking the resulting string would not carry the real
    credentials.
    """
    conn_string = db_url.render_as_string(hide_password=False)
    return conn_string
|
||||||
|
|
||||||
|
|
||||||
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||||
_validate_mysql_driver(db_url)
|
_validate_mysql_driver(db_url)
|
||||||
|
|
||||||
@@ -46,7 +50,7 @@ async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size:
|
|||||||
autoflush=False,
|
autoflush=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
saver_cm = AIOMySQLSaver.from_conn_string(db_url)
|
saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(db_url))
|
||||||
checkpointer = await saver_cm.__aenter__()
|
checkpointer = await saver_cm.__aenter__()
|
||||||
|
|
||||||
async def setup() -> None:
|
async def setup() -> None:
|
||||||
|
|||||||
@@ -10,6 +10,10 @@ from store.persistence.shared import close_in_order
|
|||||||
from store.persistence.types import AppPersistence
|
from store.persistence.types import AppPersistence
|
||||||
|
|
||||||
|
|
||||||
|
def _checkpoint_conn_string(db_url: URL) -> str:
    """Render *db_url* as a ``postgresql://`` connection string with password.

    The driver name is replaced by plain ``postgresql`` before rendering.
    NOTE(review): presumably the checkpoint saver's driver does not accept a
    SQLAlchemy ``postgresql+<driver>`` scheme -- confirm.
    """
    plain_url = db_url.set(drivername="postgresql")
    return plain_url.render_as_string(hide_password=False)
|
||||||
|
|
||||||
|
|
||||||
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||||
|
|
||||||
@@ -31,7 +35,7 @@ async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_si
|
|||||||
autoflush=False,
|
autoflush=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
saver_cm = AsyncPostgresSaver.from_conn_string(db_url)
|
saver_cm = AsyncPostgresSaver.from_conn_string(_checkpoint_conn_string(db_url))
|
||||||
checkpointer = await saver_cm.__aenter__()
|
checkpointer = await saver_cm.__aenter__()
|
||||||
|
|
||||||
async def setup() -> None:
|
async def setup() -> None:
|
||||||
|
|||||||
@@ -0,0 +1,192 @@
|
|||||||
|
"""Dialect-aware JSON value matching for storage SQLAlchemy repositories."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import BigInteger, Float, String, bindparam
|
||||||
|
from sqlalchemy.ext.compiler import compiles
|
||||||
|
from sqlalchemy.sql.compiler import SQLCompiler
|
||||||
|
from sqlalchemy.sql.expression import ColumnElement
|
||||||
|
from sqlalchemy.sql.visitors import InternalTraversal
|
||||||
|
from sqlalchemy.types import Boolean, TypeEngine
|
||||||
|
|
||||||
|
# Filter keys are interpolated into SQL JSON-path literals by the dialect
# compilers, so only this conservative character set is accepted.
_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)

# Bounds of a signed 64-bit integer; integer filter values outside this range
# are rejected.
_INT64_MIN = -(2**63)
_INT64_MAX = 2**63 - 1


def validate_metadata_filter_key(key: object) -> bool:
    """Return True when *key* is safe for JSON metadata filter SQL paths.

    Only non-empty strings made entirely of ASCII letters, digits, ``_`` and
    ``-`` are accepted.
    """
    # fullmatch (not match): with ``match`` the trailing ``$`` also matches
    # just before a final newline, which would let "key\n" through.
    return isinstance(key, str) and _KEY_CHARSET_RE.fullmatch(key) is not None


def validate_metadata_filter_value(value: object) -> bool:
    """Return True when *value* can be compiled into a portable JSON predicate.

    Accepted values are ``None``, ``bool``, ``str``, ``int`` within the signed
    64-bit range, and finite ``float``.
    """
    import math  # local import keeps this block self-contained

    if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
        return False
    # bool is a subclass of int; handle it first so it skips the range check.
    if isinstance(value, bool):
        return True
    if isinstance(value, int):
        return _INT64_MIN <= value <= _INT64_MAX
    if isinstance(value, float):
        # NaN and +/-inf are not valid JSON numbers, so they can never equal a
        # stored value and are not portable as bound parameters.
        return math.isfinite(value)
    return True
|
||||||
|
|
||||||
|
|
||||||
|
class JsonMatch(ColumnElement[bool]):
    """Dialect-portable ``column[key] == value`` for JSON columns.

    The element only stores its inputs; the SQL text is produced by the
    per-dialect ``@compiles`` hooks registered for this class.

    Raises:
        ValueError: if *key* fails ``validate_metadata_filter_key``.
        TypeError: if *value* is a non-finite float, an int outside the signed
            64-bit range, or not one of None/bool/int/float/str.
    """

    inherit_cache = True
    type = Boolean()
    _is_implicitly_boolean = True

    # Attributes SQLAlchemy traverses for this element.
    # NOTE(review): ``value_type`` is presumably listed so that equal values of
    # different Python types (e.g. ``True`` and ``1``) stay distinct -- confirm.
    _traverse_internals = [
        ("column", InternalTraversal.dp_clauseelement),
        ("key", InternalTraversal.dp_string),
        ("value", InternalTraversal.dp_plain_obj),
        ("value_type", InternalTraversal.dp_string),
    ]

    def __init__(self, column: ColumnElement[Any], key: str, value: object) -> None:
        from math import isfinite  # local import keeps this block self-contained

        if not validate_metadata_filter_key(key):
            raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")
        # NaN / +-inf are not valid JSON numbers, so they can never match a
        # stored value; reject them here with an explicit message.
        if isinstance(value, float) and not isfinite(value):
            raise TypeError(f"JsonMatch float value must be finite; got: {value!r}")
        if not validate_metadata_filter_value(value):
            # bool is a subclass of int, so exclude it from the range message.
            if isinstance(value, int) and not isinstance(value, bool):
                raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}")
            raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}")
        self.column = column
        self.key = key
        self.value = value
        self.value_type = type(value).__qualname__
        super().__init__()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class _Dialect:
    """Immutable table of dialect-specific names used to build JSON predicates."""

    # Type name the dialect's JSON type function reports for a JSON null.
    null_type: str
    # Type names accepted when comparing against a float, and the cast target.
    num_types: tuple[str, ...]
    num_cast: str
    # Type names accepted when comparing against an int, and the cast target.
    int_types: tuple[str, ...]
    int_cast: str
    # Optional quoted regex literal; when set, the extracted text must match
    # it before the integer cast is attempted.
    int_guard: str | None
    # Type name reported for a JSON string.
    string_type: str
    # Type name reported for a JSON boolean; None means the type function
    # itself yields the true/false names below.
    bool_type: str | None
    # Text compared against for True / False.
    true_value: str
    false_value: str
|
||||||
|
|
||||||
|
|
||||||
|
# SQLite: no separate boolean type name is configured (bool_type=None), and no
# regex guard is used before the integer cast.
_SQLITE = _Dialect(
    null_type="null",
    string_type="text",
    bool_type=None,
    true_value="true",
    false_value="false",
    num_types=("integer", "real"),
    num_cast="REAL",
    int_types=("integer",),
    int_cast="INTEGER",
    int_guard=None,
)

# PostgreSQL: a single "number" type covers ints and floats, so integer
# comparisons carry a digits-only regex guard.
_POSTGRES = _Dialect(
    null_type="null",
    string_type="string",
    bool_type="boolean",
    true_value="true",
    false_value="false",
    num_types=("number",),
    num_cast="DOUBLE PRECISION",
    int_types=("number",),
    int_cast="BIGINT",
    int_guard="'^-?[0-9]+$'",
)

# MySQL: upper-case type names; integers are cast with SIGNED.
_MYSQL = _Dialect(
    null_type="NULL",
    string_type="STRING",
    bool_type="BOOLEAN",
    true_value="true",
    false_value="false",
    num_types=("INTEGER", "DOUBLE", "DECIMAL"),
    num_cast="DOUBLE",
    int_types=("INTEGER",),
    int_cast="SIGNED",
    int_guard=None,
)
|
||||||
|
|
||||||
|
|
||||||
|
def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
    """Compile *value* as an anonymous bound parameter of type *sa_type*.

    Returns whatever ``compiler.process`` renders for that parameter.
    """
    return compiler.process(bindparam(None, value, type_=sa_type), **kw)
|
||||||
|
|
||||||
|
|
||||||
|
def _type_check(typeof: str, types: tuple[str, ...]) -> str:
    """Return SQL testing that *typeof* equals one of the quoted *types*.

    A single type name yields ``typeof = 'name'``; any other count yields
    ``typeof IN ('a', 'b', ...)``.
    """
    if len(types) != 1:
        literals = [f"'{name}'" for name in types]
        return f"{typeof} IN ({', '.join(literals)})"
    (only,) = types
    return f"{typeof} = '{only}'"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
    """Build the SQL predicate matching a JSON member against *value*.

    *typeof* and *extract* are already-rendered SQL fragments giving the
    member's JSON type and its extracted value; *dialect* supplies the type
    names and cast targets. Ints, floats and strings are passed as bound
    parameters; None and bool are compared against dialect literals.
    """
    if value is None:
        return f"{typeof} = '{dialect.null_type}'"

    # bool must be tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        literal = dialect.true_value if value else dialect.false_value
        if dialect.bool_type is not None:
            return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{literal}')"
        return f"{typeof} = '{literal}'"

    if isinstance(value, int):
        placeholder = _bind(compiler, value, BigInteger(), **kw)
        is_integer = _type_check(typeof, dialect.int_types)
        as_integer = f"CAST({extract} AS {dialect.int_cast})"
        if not dialect.int_guard:
            return f"({is_integer} AND {as_integer} = {placeholder})"
        # With a guard, the cast sits inside a CASE that also requires the
        # extracted text to match the guard regex.
        return f"(CASE WHEN {is_integer} AND {extract} ~ {dialect.int_guard} THEN {as_integer} END = {placeholder})"

    if isinstance(value, float):
        placeholder = _bind(compiler, value, Float(), **kw)
        is_number = _type_check(typeof, dialect.num_types)
        return f"({is_number} AND CAST({extract} AS {dialect.num_cast}) = {placeholder})"

    placeholder = _bind(compiler, str(value), String(), **kw)
    return f"({typeof} = '{dialect.string_type}' AND {extract} = {placeholder})"
|
||||||
|
|
||||||
|
|
||||||
|
@compiles(JsonMatch, "sqlite")
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    """Render a ``JsonMatch`` for SQLite using ``json_type`` / ``json_extract``.

    Raises:
        ValueError: if the element's key no longer passes key validation.
    """
    key = element.key
    # Re-check at compile time: the key is interpolated into the SQL text.
    if not validate_metadata_filter_key(key):
        raise ValueError(f"Key escaped validation: {key!r}")
    column_sql = compiler.process(element.column, **kw)
    json_path = f'$."{key}"'
    type_sql = f"json_type({column_sql}, '{json_path}')"
    value_sql = f"json_extract({column_sql}, '{json_path}')"
    return _build_clause(compiler, type_sql, value_sql, element.value, _SQLITE, **kw)
|
||||||
|
|
||||||
|
|
||||||
|
@compiles(JsonMatch, "postgresql")
def _compile_postgres(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    """Render a ``JsonMatch`` for PostgreSQL using ``json_typeof`` and ``->>``.

    Raises:
        ValueError: if the element's key no longer passes key validation.
    """
    key = element.key
    # Re-check at compile time: the key is interpolated into the SQL text.
    if not validate_metadata_filter_key(key):
        raise ValueError(f"Key escaped validation: {key!r}")
    column_sql = compiler.process(element.column, **kw)
    type_sql = f"json_typeof({column_sql} -> '{key}')"
    value_sql = f"({column_sql} ->> '{key}')"
    return _build_clause(compiler, type_sql, value_sql, element.value, _POSTGRES, **kw)
|
||||||
|
|
||||||
|
|
||||||
|
@compiles(JsonMatch, "mysql")
def _compile_mysql(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    """Render a ``JsonMatch`` for MySQL using ``JSON_TYPE`` / ``JSON_EXTRACT``.

    The extracted value is wrapped in ``JSON_UNQUOTE`` before comparison.

    Raises:
        ValueError: if the element's key no longer passes key validation.
    """
    key = element.key
    # Re-check at compile time: the key is interpolated into the SQL text.
    if not validate_metadata_filter_key(key):
        raise ValueError(f"Key escaped validation: {key!r}")
    column_sql = compiler.process(element.column, **kw)
    json_path = f'$."{key}"'
    member_sql = f"JSON_EXTRACT({column_sql}, '{json_path}')"
    type_sql = f"JSON_TYPE({member_sql})"
    value_sql = f"JSON_UNQUOTE({member_sql})"
    return _build_clause(compiler, type_sql, value_sql, element.value, _MYSQL, **kw)
|
||||||
|
|
||||||
|
|
||||||
|
@compiles(JsonMatch)
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    """Fallback compiler for dialects without a dedicated ``JsonMatch`` hook.

    Raises:
        NotImplementedError: always, naming the unsupported dialect.
    """
    dialect_name = compiler.dialect.name
    raise NotImplementedError(f"JsonMatch supports sqlite, postgresql, and mysql; got dialect: {dialect_name}")
|
||||||
|
|
||||||
|
|
||||||
|
def json_match(column: ColumnElement[Any], key: str, value: object) -> JsonMatch:
    """Build a ``JsonMatch`` predicate for ``column[key] == value``.

    Thin factory around the ``JsonMatch`` constructor; any exception raised
    by that constructor propagates unchanged.
    """
    predicate = JsonMatch(column, key, value)
    return predicate
|
||||||
@@ -3,6 +3,7 @@ from store.repositories.contracts import (
|
|||||||
FeedbackAggregate,
|
FeedbackAggregate,
|
||||||
FeedbackCreate,
|
FeedbackCreate,
|
||||||
FeedbackRepositoryProtocol,
|
FeedbackRepositoryProtocol,
|
||||||
|
InvalidMetadataFilterError,
|
||||||
Run,
|
Run,
|
||||||
RunCreate,
|
RunCreate,
|
||||||
RunEvent,
|
RunEvent,
|
||||||
@@ -30,6 +31,7 @@ __all__ = [
|
|||||||
"FeedbackAggregate",
|
"FeedbackAggregate",
|
||||||
"FeedbackCreate",
|
"FeedbackCreate",
|
||||||
"FeedbackRepositoryProtocol",
|
"FeedbackRepositoryProtocol",
|
||||||
|
"InvalidMetadataFilterError",
|
||||||
"Run",
|
"Run",
|
||||||
"RunCreate",
|
"RunCreate",
|
||||||
"RunEvent",
|
"RunEvent",
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from store.repositories.contracts.run_event import (
|
|||||||
RunEventRepositoryProtocol,
|
RunEventRepositoryProtocol,
|
||||||
)
|
)
|
||||||
from store.repositories.contracts.thread_meta import (
|
from store.repositories.contracts.thread_meta import (
|
||||||
|
InvalidMetadataFilterError,
|
||||||
ThreadMeta,
|
ThreadMeta,
|
||||||
ThreadMetaCreate,
|
ThreadMetaCreate,
|
||||||
ThreadMetaRepositoryProtocol,
|
ThreadMetaRepositoryProtocol,
|
||||||
@@ -37,6 +38,7 @@ __all__ = [
|
|||||||
"RunEventCreate",
|
"RunEventCreate",
|
||||||
"RunEventRepositoryProtocol",
|
"RunEventRepositoryProtocol",
|
||||||
"RunRepositoryProtocol",
|
"RunRepositoryProtocol",
|
||||||
|
"InvalidMetadataFilterError",
|
||||||
"ThreadMeta",
|
"ThreadMeta",
|
||||||
"ThreadMetaCreate",
|
"ThreadMetaCreate",
|
||||||
"ThreadMetaRepositoryProtocol",
|
"ThreadMetaRepositoryProtocol",
|
||||||
|
|||||||
@@ -6,6 +6,10 @@ from typing import Any, Protocol
|
|||||||
from pydantic import BaseModel, ConfigDict, Field
|
from pydantic import BaseModel, ConfigDict, Field
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidMetadataFilterError(ValueError):
    """Signal that every client-supplied metadata filter was rejected.

    Subclasses ``ValueError`` so existing ``except ValueError`` handlers
    still catch it.
    """
|
||||||
|
|
||||||
|
|
||||||
class ThreadMetaCreate(BaseModel):
|
class ThreadMetaCreate(BaseModel):
|
||||||
model_config = ConfigDict(extra="forbid")
|
model_config = ConfigDict(extra="forbid")
|
||||||
|
|
||||||
|
|||||||
@@ -153,9 +153,10 @@ class DbRunRepository(RunRepositoryProtocol):
|
|||||||
|
|
||||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||||
completed = RunModel.status.in_(("success", "error"))
|
completed = RunModel.status.in_(("success", "error"))
|
||||||
|
model_expr = func.coalesce(RunModel.model_name, "unknown")
|
||||||
stmt = (
|
stmt = (
|
||||||
select(
|
select(
|
||||||
func.coalesce(RunModel.model_name, "unknown").label("model"),
|
model_expr.label("model"),
|
||||||
func.count().label("runs"),
|
func.count().label("runs"),
|
||||||
func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
|
func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
|
||||||
func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
|
func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
|
||||||
@@ -165,7 +166,7 @@ class DbRunRepository(RunRepositoryProtocol):
|
|||||||
func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
|
func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
|
||||||
)
|
)
|
||||||
.where(RunModel.thread_id == thread_id, completed)
|
.where(RunModel.thread_id == thread_id, completed)
|
||||||
.group_by(func.coalesce(RunModel.model_name, "unknown"))
|
.group_by(model_expr)
|
||||||
)
|
)
|
||||||
|
|
||||||
rows = (await self._session.execute(stmt)).all()
|
rows = (await self._session.execute(stmt)).all()
|
||||||
|
|||||||
@@ -56,8 +56,10 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
|||||||
seq_by_thread: dict[str, int] = {}
|
seq_by_thread: dict[str, int] = {}
|
||||||
for thread_id in thread_ids:
|
for thread_id in thread_ids:
|
||||||
max_seq = await self._session.scalar(
|
max_seq = await self._session.scalar(
|
||||||
select(func.max(RunEventModel.seq))
|
select(RunEventModel.seq)
|
||||||
.where(RunEventModel.thread_id == thread_id)
|
.where(RunEventModel.thread_id == thread_id)
|
||||||
|
.order_by(RunEventModel.seq.desc())
|
||||||
|
.limit(1)
|
||||||
.with_for_update()
|
.with_for_update()
|
||||||
)
|
)
|
||||||
seq_by_thread[thread_id] = max_seq or 0
|
seq_by_thread[thread_id] = max_seq or 0
|
||||||
|
|||||||
@@ -1,13 +1,22 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import delete, select, update
|
from sqlalchemy import delete, select, update
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from store.repositories.contracts.thread_meta import ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol
|
from store.persistence.json_compat import json_match
|
||||||
|
from store.repositories.contracts.thread_meta import (
|
||||||
|
InvalidMetadataFilterError,
|
||||||
|
ThreadMeta,
|
||||||
|
ThreadMetaCreate,
|
||||||
|
ThreadMetaRepositoryProtocol,
|
||||||
|
)
|
||||||
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
|
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
|
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
|
||||||
return ThreadMeta(
|
return ThreadMeta(
|
||||||
@@ -87,10 +96,18 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
|
|||||||
if assistant_id is not None:
|
if assistant_id is not None:
|
||||||
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
|
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
|
||||||
if metadata:
|
if metadata:
|
||||||
|
applied = 0
|
||||||
for key, value in metadata.items():
|
for key, value in metadata.items():
|
||||||
stmt = stmt.where(ThreadMetaModel.meta[key].as_string() == str(value))
|
try:
|
||||||
|
stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
|
||||||
|
applied += 1
|
||||||
|
except (ValueError, TypeError) as exc:
|
||||||
|
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
|
||||||
|
if applied == 0:
|
||||||
|
rejected_keys = ", ".join(sorted(str(key) for key in metadata))
|
||||||
|
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
|
||||||
|
|
||||||
stmt = stmt.order_by(ThreadMetaModel.created_time.desc())
|
stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
|
||||||
stmt = stmt.limit(limit).offset(offset)
|
stmt = stmt.limit(limit).offset(offset)
|
||||||
|
|
||||||
result = await self._session.execute(stmt)
|
result = await self._session.execute(stmt)
|
||||||
|
|||||||
@@ -0,0 +1,89 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import Column, MetaData, String, Table
|
||||||
|
from sqlalchemy.dialects import mysql, postgresql
|
||||||
|
from sqlalchemy.types import JSON
|
||||||
|
|
||||||
|
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||||
|
|
||||||
|
from store.persistence.json_compat import json_match
|
||||||
|
|
||||||
|
|
||||||
|
def _table():
    """Return a throwaway table ``t`` with a JSON ``data`` column and a string ``id``."""
    return Table("t", MetaData(), Column("data", JSON), Column("id", String))
|
||||||
|
|
||||||
|
|
||||||
|
def test_storage_json_match_compiles_sqlite() -> None:
    """Check the SQL rendered for None, bool, int and float on SQLite."""
    from sqlalchemy import create_engine

    data_col = _table().c.data
    sqlite_dialect = create_engine("sqlite://").dialect

    def render(value: object) -> str:
        # Compile with literal binds so the SQL text can be compared directly.
        clause = json_match(data_col, "k", value)
        return str(clause.compile(dialect=sqlite_dialect, compile_kwargs={"literal_binds": True}))

    assert render(None) == "json_type(t.data, '$.\"k\"') = 'null'"
    assert render(True) == "json_type(t.data, '$.\"k\"') = 'true'"

    int_sql = render(42)
    assert "= 'integer'" in int_sql
    assert "CAST" in int_sql

    float_sql = render(3.14)
    assert "IN ('integer', 'real')" in float_sql
    assert "REAL" in float_sql
|
||||||
|
|
||||||
|
|
||||||
|
def test_storage_json_match_compiles_postgres() -> None:
    """Check the SQL rendered for None, bool and int on PostgreSQL."""
    data_col = _table().c.data
    pg_dialect = postgresql.dialect()

    def render(value: object) -> str:
        # Compile with literal binds so the SQL text can be compared directly.
        clause = json_match(data_col, "k", value)
        return str(clause.compile(dialect=pg_dialect, compile_kwargs={"literal_binds": True}))

    assert render(None) == "json_typeof(t.data -> 'k') = 'null'"
    assert render(False) == "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"

    int_sql = render(42)
    for fragment in ("CASE WHEN", "BIGINT", "'^-?[0-9]+$'"):
        assert fragment in int_sql
|
||||||
|
|
||||||
|
|
||||||
|
def test_storage_json_match_compiles_mysql() -> None:
    """Check the SQL rendered for None, bool and int on MySQL."""
    data_col = _table().c.data
    mysql_dialect = mysql.dialect()

    def render(value: object) -> str:
        # Compile with literal binds so the SQL text can be compared directly.
        clause = json_match(data_col, "k", value)
        return str(clause.compile(dialect=mysql_dialect, compile_kwargs={"literal_binds": True}))

    assert render(None) == "JSON_TYPE(JSON_EXTRACT(t.data, '$.\"k\"')) = 'NULL'"

    bool_sql = render(True)
    for fragment in ("JSON_TYPE(JSON_EXTRACT", "= 'BOOLEAN'", "= 'true'"):
        assert fragment in bool_sql

    int_sql = render(42)
    for fragment in ("= 'INTEGER'", "SIGNED"):
        assert fragment in int_sql
|
||||||
|
|
||||||
|
|
||||||
|
def test_storage_json_match_rejects_unsafe_keys_and_values() -> None:
    """Check that invalid keys raise ValueError and invalid values raise TypeError."""
    data_col = _table().c.data

    unsafe_keys = ["a.b", "bad;key", "with space", "", 42, None]
    for candidate in unsafe_keys:
        with pytest.raises(ValueError, match="JsonMatch key must match"):
            json_match(data_col, candidate, "x")  # type: ignore[arg-type]

    unsupported_values = [[], {}, object()]
    for candidate in unsupported_values:
        with pytest.raises(TypeError, match="JsonMatch value must be"):
            json_match(data_col, "k", candidate)

    # One past the largest signed 64-bit integer.
    with pytest.raises(TypeError, match="out of signed 64-bit range"):
        json_match(data_col, "k", 2**63)
|
||||||
@@ -12,6 +12,7 @@ os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().pare
|
|||||||
from store.persistence import create_persistence_from_database_config
|
from store.persistence import create_persistence_from_database_config
|
||||||
from store.repositories import (
|
from store.repositories import (
|
||||||
FeedbackCreate,
|
FeedbackCreate,
|
||||||
|
InvalidMetadataFilterError,
|
||||||
RunCreate,
|
RunCreate,
|
||||||
RunEventCreate,
|
RunEventCreate,
|
||||||
ThreadMetaCreate,
|
ThreadMetaCreate,
|
||||||
@@ -175,6 +176,77 @@ async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
|
|||||||
await persistence.aclose()
|
await persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filters_are_type_safe(tmp_path):
    """Check that True, False, 1 and None filters each match only their own row."""
    seeded = {
        "bool-true": {"value": True},
        "bool-false": {"value": False},
        "int-one": {"value": 1},
        "null-value": {"value": None},
        "missing-value": {"other": "x"},
    }
    expectations = [(True, "bool-true"), (False, "bool-false"), (1, "int-one"), (None, "null-value")]

    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            for thread_id, metadata in seeded.items():
                await repo.create_thread_meta(ThreadMetaCreate(thread_id=thread_id, metadata=metadata))
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            for wanted, expected_id in expectations:
                rows = await repo.search_threads(metadata={"value": wanted})
                assert [row.thread_id for row in rows] == [expected_id]
    finally:
        await persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filters_paginate_after_sql_match(tmp_path):
    """Check limit/offset paging over the ten of thirty rows matching the filter."""
    wanted = {"target": "yes"}

    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            for index in range(30):
                # Every third thread (10 in total) carries the matching value.
                target = "yes" if index % 3 == 0 else "no"
                payload = ThreadMetaCreate(thread_id=f"thread-{index:02d}", metadata={"target": target})
                await repo.create_thread_meta(payload)
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            first_page = await repo.search_threads(metadata=wanted, limit=3, offset=0)
            second_page = await repo.search_threads(metadata=wanted, limit=3, offset=3)
            last_page = await repo.search_threads(metadata=wanted, limit=3, offset=9)

        assert len(first_page) == 3
        assert len(second_page) == 3
        assert len(last_page) == 1
        first_ids = {row.thread_id for row in first_page}
        second_ids = {row.thread_id for row in second_page}
        assert first_ids.isdisjoint(second_ids)
    finally:
        await persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filter_rejects_invalid_entries(tmp_path):
    """Check that a bad filter is skipped when mixed with a valid one, and raises when alone."""
    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            for thread_id, env in (("thread-1", "prod"), ("thread-2", "staging")):
                await repo.create_thread_meta(ThreadMetaCreate(thread_id=thread_id, metadata={"env": env}))
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)

            # One valid filter plus one invalid key: only the valid one applies.
            partial = await repo.search_threads(metadata={"env": "prod", "bad;key": "ignored"})
            assert [row.thread_id for row in partial] == ["thread-1"]

            # Every filter invalid (bad key, then unsupported value type).
            with pytest.raises(InvalidMetadataFilterError, match="rejected"):
                await repo.search_threads(metadata={"bad;key": "x"})
            with pytest.raises(InvalidMetadataFilterError, match="rejected"):
                await repo.search_threads(metadata={"env": ["prod", "staging"]})
    finally:
        await persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
|
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
|
||||||
persistence = await _make_persistence(tmp_path)
|
persistence = await _make_persistence(tmp_path)
|
||||||
|
|||||||
Reference in New Issue
Block a user