fix(storage): harden sql persistence compatibility
This commit is contained in:
@@ -11,7 +11,7 @@ from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def _validate_mysql_driver(db_url: str) -> str:
|
||||
def _validate_mysql_driver(db_url: URL) -> str:
|
||||
url = make_url(db_url)
|
||||
driver = url.get_driver_name()
|
||||
|
||||
@@ -23,6 +23,10 @@ def _validate_mysql_driver(db_url: str) -> str:
|
||||
return driver
|
||||
|
||||
|
||||
def _checkpoint_conn_string(db_url: URL) -> str:
|
||||
return db_url.render_as_string(hide_password=False)
|
||||
|
||||
|
||||
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||
_validate_mysql_driver(db_url)
|
||||
|
||||
@@ -46,7 +50,7 @@ async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size:
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AIOMySQLSaver.from_conn_string(db_url)
|
||||
saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(db_url))
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
|
||||
@@ -10,6 +10,10 @@ from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def _checkpoint_conn_string(db_url: URL) -> str:
|
||||
return db_url.set(drivername="postgresql").render_as_string(hide_password=False)
|
||||
|
||||
|
||||
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
@@ -31,7 +35,7 @@ async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_si
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AsyncPostgresSaver.from_conn_string(db_url)
|
||||
saver_cm = AsyncPostgresSaver.from_conn_string(_checkpoint_conn_string(db_url))
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
"""Dialect-aware JSON value matching for storage SQLAlchemy repositories."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import BigInteger, Float, String, bindparam
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql.compiler import SQLCompiler
|
||||
from sqlalchemy.sql.expression import ColumnElement
|
||||
from sqlalchemy.sql.visitors import InternalTraversal
|
||||
from sqlalchemy.types import Boolean, TypeEngine
|
||||
|
||||
_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)
|
||||
|
||||
_INT64_MIN = -(2**63)
|
||||
_INT64_MAX = 2**63 - 1
|
||||
|
||||
|
||||
def validate_metadata_filter_key(key: object) -> bool:
|
||||
"""Return True when *key* is safe for JSON metadata filter SQL paths."""
|
||||
return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key))
|
||||
|
||||
|
||||
def validate_metadata_filter_value(value: object) -> bool:
|
||||
"""Return True when *value* can be compiled into a portable JSON predicate."""
|
||||
if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
|
||||
return False
|
||||
if isinstance(value, int) and not isinstance(value, bool):
|
||||
return _INT64_MIN <= value <= _INT64_MAX
|
||||
return True
|
||||
|
||||
|
||||
class JsonMatch(ColumnElement[bool]):
|
||||
"""Dialect-portable ``column[key] == value`` for JSON columns."""
|
||||
|
||||
inherit_cache = True
|
||||
type = Boolean()
|
||||
_is_implicitly_boolean = True
|
||||
|
||||
_traverse_internals = [
|
||||
("column", InternalTraversal.dp_clauseelement),
|
||||
("key", InternalTraversal.dp_string),
|
||||
("value", InternalTraversal.dp_plain_obj),
|
||||
("value_type", InternalTraversal.dp_string),
|
||||
]
|
||||
|
||||
def __init__(self, column: ColumnElement[Any], key: str, value: object) -> None:
|
||||
if not validate_metadata_filter_key(key):
|
||||
raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")
|
||||
if not validate_metadata_filter_value(value):
|
||||
if isinstance(value, int) and not isinstance(value, bool):
|
||||
raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}")
|
||||
raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}")
|
||||
self.column = column
|
||||
self.key = key
|
||||
self.value = value
|
||||
self.value_type = type(value).__qualname__
|
||||
super().__init__()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _Dialect:
|
||||
null_type: str
|
||||
num_types: tuple[str, ...]
|
||||
num_cast: str
|
||||
int_types: tuple[str, ...]
|
||||
int_cast: str
|
||||
int_guard: str | None
|
||||
string_type: str
|
||||
bool_type: str | None
|
||||
true_value: str
|
||||
false_value: str
|
||||
|
||||
|
||||
_SQLITE = _Dialect(
|
||||
null_type="null",
|
||||
num_types=("integer", "real"),
|
||||
num_cast="REAL",
|
||||
int_types=("integer",),
|
||||
int_cast="INTEGER",
|
||||
int_guard=None,
|
||||
string_type="text",
|
||||
bool_type=None,
|
||||
true_value="true",
|
||||
false_value="false",
|
||||
)
|
||||
|
||||
_POSTGRES = _Dialect(
|
||||
null_type="null",
|
||||
num_types=("number",),
|
||||
num_cast="DOUBLE PRECISION",
|
||||
int_types=("number",),
|
||||
int_cast="BIGINT",
|
||||
int_guard="'^-?[0-9]+$'",
|
||||
string_type="string",
|
||||
bool_type="boolean",
|
||||
true_value="true",
|
||||
false_value="false",
|
||||
)
|
||||
|
||||
_MYSQL = _Dialect(
|
||||
null_type="NULL",
|
||||
num_types=("INTEGER", "DOUBLE", "DECIMAL"),
|
||||
num_cast="DOUBLE",
|
||||
int_types=("INTEGER",),
|
||||
int_cast="SIGNED",
|
||||
int_guard=None,
|
||||
string_type="STRING",
|
||||
bool_type="BOOLEAN",
|
||||
true_value="true",
|
||||
false_value="false",
|
||||
)
|
||||
|
||||
|
||||
def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
|
||||
param = bindparam(None, value, type_=sa_type)
|
||||
return compiler.process(param, **kw)
|
||||
|
||||
|
||||
def _type_check(typeof: str, types: tuple[str, ...]) -> str:
|
||||
if len(types) == 1:
|
||||
return f"{typeof} = '{types[0]}'"
|
||||
quoted = ", ".join(f"'{type_name}'" for type_name in types)
|
||||
return f"{typeof} IN ({quoted})"
|
||||
|
||||
|
||||
def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
|
||||
if value is None:
|
||||
return f"{typeof} = '{dialect.null_type}'"
|
||||
if isinstance(value, bool):
|
||||
bool_str = dialect.true_value if value else dialect.false_value
|
||||
if dialect.bool_type is None:
|
||||
return f"{typeof} = '{bool_str}'"
|
||||
return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')"
|
||||
if isinstance(value, int):
|
||||
bp = _bind(compiler, value, BigInteger(), **kw)
|
||||
if dialect.int_guard:
|
||||
return (
|
||||
f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} "
|
||||
f"THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
|
||||
)
|
||||
return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})"
|
||||
if isinstance(value, float):
|
||||
bp = _bind(compiler, value, Float(), **kw)
|
||||
return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})"
|
||||
bp = _bind(compiler, str(value), String(), **kw)
|
||||
return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})"
|
||||
|
||||
|
||||
@compiles(JsonMatch, "sqlite")
|
||||
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
if not validate_metadata_filter_key(element.key):
|
||||
raise ValueError(f"Key escaped validation: {element.key!r}")
|
||||
col = compiler.process(element.column, **kw)
|
||||
path = f'$."{element.key}"'
|
||||
typeof = f"json_type({col}, '{path}')"
|
||||
extract = f"json_extract({col}, '{path}')"
|
||||
return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw)
|
||||
|
||||
|
||||
@compiles(JsonMatch, "postgresql")
|
||||
def _compile_postgres(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
if not validate_metadata_filter_key(element.key):
|
||||
raise ValueError(f"Key escaped validation: {element.key!r}")
|
||||
col = compiler.process(element.column, **kw)
|
||||
typeof = f"json_typeof({col} -> '{element.key}')"
|
||||
extract = f"({col} ->> '{element.key}')"
|
||||
return _build_clause(compiler, typeof, extract, element.value, _POSTGRES, **kw)
|
||||
|
||||
|
||||
@compiles(JsonMatch, "mysql")
|
||||
def _compile_mysql(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
if not validate_metadata_filter_key(element.key):
|
||||
raise ValueError(f"Key escaped validation: {element.key!r}")
|
||||
col = compiler.process(element.column, **kw)
|
||||
path = f'$."{element.key}"'
|
||||
typeof = f"JSON_TYPE(JSON_EXTRACT({col}, '{path}'))"
|
||||
extract = f"JSON_UNQUOTE(JSON_EXTRACT({col}, '{path}'))"
|
||||
return _build_clause(compiler, typeof, extract, element.value, _MYSQL, **kw)
|
||||
|
||||
|
||||
@compiles(JsonMatch)
|
||||
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
raise NotImplementedError(f"JsonMatch supports sqlite, postgresql, and mysql; got dialect: {compiler.dialect.name}")
|
||||
|
||||
|
||||
def json_match(column: ColumnElement[Any], key: str, value: object) -> JsonMatch:
|
||||
return JsonMatch(column, key, value)
|
||||
@@ -3,6 +3,7 @@ from store.repositories.contracts import (
|
||||
FeedbackAggregate,
|
||||
FeedbackCreate,
|
||||
FeedbackRepositoryProtocol,
|
||||
InvalidMetadataFilterError,
|
||||
Run,
|
||||
RunCreate,
|
||||
RunEvent,
|
||||
@@ -30,6 +31,7 @@ __all__ = [
|
||||
"FeedbackAggregate",
|
||||
"FeedbackCreate",
|
||||
"FeedbackRepositoryProtocol",
|
||||
"InvalidMetadataFilterError",
|
||||
"Run",
|
||||
"RunCreate",
|
||||
"RunEvent",
|
||||
|
||||
@@ -15,6 +15,7 @@ from store.repositories.contracts.run_event import (
|
||||
RunEventRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.contracts.thread_meta import (
|
||||
InvalidMetadataFilterError,
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
@@ -37,6 +38,7 @@ __all__ = [
|
||||
"RunEventCreate",
|
||||
"RunEventRepositoryProtocol",
|
||||
"RunRepositoryProtocol",
|
||||
"InvalidMetadataFilterError",
|
||||
"ThreadMeta",
|
||||
"ThreadMetaCreate",
|
||||
"ThreadMetaRepositoryProtocol",
|
||||
|
||||
@@ -6,6 +6,10 @@ from typing import Any, Protocol
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class InvalidMetadataFilterError(ValueError):
|
||||
"""Raised when all client-supplied metadata filters are rejected."""
|
||||
|
||||
|
||||
class ThreadMetaCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
@@ -153,9 +153,10 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
completed = RunModel.status.in_(("success", "error"))
|
||||
model_expr = func.coalesce(RunModel.model_name, "unknown")
|
||||
stmt = (
|
||||
select(
|
||||
func.coalesce(RunModel.model_name, "unknown").label("model"),
|
||||
model_expr.label("model"),
|
||||
func.count().label("runs"),
|
||||
func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
|
||||
func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
|
||||
@@ -165,7 +166,7 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
|
||||
)
|
||||
.where(RunModel.thread_id == thread_id, completed)
|
||||
.group_by(func.coalesce(RunModel.model_name, "unknown"))
|
||||
.group_by(model_expr)
|
||||
)
|
||||
|
||||
rows = (await self._session.execute(stmt)).all()
|
||||
|
||||
@@ -56,8 +56,10 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
seq_by_thread: dict[str, int] = {}
|
||||
for thread_id in thread_ids:
|
||||
max_seq = await self._session.scalar(
|
||||
select(func.max(RunEventModel.seq))
|
||||
select(RunEventModel.seq)
|
||||
.where(RunEventModel.thread_id == thread_id)
|
||||
.order_by(RunEventModel.seq.desc())
|
||||
.limit(1)
|
||||
.with_for_update()
|
||||
)
|
||||
seq_by_thread[thread_id] = max_seq or 0
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.thread_meta import ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol
|
||||
from store.persistence.json_compat import json_match
|
||||
from store.repositories.contracts.thread_meta import (
|
||||
InvalidMetadataFilterError,
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
|
||||
return ThreadMeta(
|
||||
@@ -87,10 +96,18 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
|
||||
if assistant_id is not None:
|
||||
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
|
||||
if metadata:
|
||||
applied = 0
|
||||
for key, value in metadata.items():
|
||||
stmt = stmt.where(ThreadMetaModel.meta[key].as_string() == str(value))
|
||||
try:
|
||||
stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
|
||||
applied += 1
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
|
||||
if applied == 0:
|
||||
rejected_keys = ", ".join(sorted(str(key) for key in metadata))
|
||||
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
|
||||
|
||||
stmt = stmt.order_by(ThreadMetaModel.created_time.desc())
|
||||
stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
|
||||
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, MetaData, String, Table
|
||||
from sqlalchemy.dialects import mysql, postgresql
|
||||
from sqlalchemy.types import JSON
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from store.persistence.json_compat import json_match
|
||||
|
||||
|
||||
def _table():
|
||||
metadata = MetaData()
|
||||
return Table("t", metadata, Column("data", JSON), Column("id", String))
|
||||
|
||||
|
||||
def test_storage_json_match_compiles_sqlite() -> None:
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
table = _table()
|
||||
dialect = create_engine("sqlite://").dialect
|
||||
|
||||
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"json_type(t.data, '$.\"k\"') = 'null'"
|
||||
)
|
||||
assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"json_type(t.data, '$.\"k\"') = 'true'"
|
||||
)
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "= 'integer'" in int_sql
|
||||
assert "CAST" in int_sql
|
||||
|
||||
float_sql = str(json_match(table.c.data, "k", 3.14).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "IN ('integer', 'real')" in float_sql
|
||||
assert "REAL" in float_sql
|
||||
|
||||
|
||||
def test_storage_json_match_compiles_postgres() -> None:
|
||||
table = _table()
|
||||
dialect = postgresql.dialect()
|
||||
|
||||
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"json_typeof(t.data -> 'k') = 'null'"
|
||||
)
|
||||
assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"
|
||||
)
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "CASE WHEN" in int_sql
|
||||
assert "BIGINT" in int_sql
|
||||
assert "'^-?[0-9]+$'" in int_sql
|
||||
|
||||
|
||||
def test_storage_json_match_compiles_mysql() -> None:
|
||||
table = _table()
|
||||
dialect = mysql.dialect()
|
||||
|
||||
null_sql = str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert null_sql == "JSON_TYPE(JSON_EXTRACT(t.data, '$.\"k\"')) = 'NULL'"
|
||||
|
||||
bool_sql = str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "JSON_TYPE(JSON_EXTRACT" in bool_sql
|
||||
assert "= 'BOOLEAN'" in bool_sql
|
||||
assert "= 'true'" in bool_sql
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "= 'INTEGER'" in int_sql
|
||||
assert "SIGNED" in int_sql
|
||||
|
||||
|
||||
def test_storage_json_match_rejects_unsafe_keys_and_values() -> None:
|
||||
table = _table()
|
||||
|
||||
for bad_key in ["a.b", "bad;key", "with space", "", 42, None]:
|
||||
with pytest.raises(ValueError, match="JsonMatch key must match"):
|
||||
json_match(table.c.data, bad_key, "x") # type: ignore[arg-type]
|
||||
|
||||
for bad_value in [[], {}, object()]:
|
||||
with pytest.raises(TypeError, match="JsonMatch value must be"):
|
||||
json_match(table.c.data, "k", bad_value)
|
||||
|
||||
with pytest.raises(TypeError, match="out of signed 64-bit range"):
|
||||
json_match(table.c.data, "k", 2**63)
|
||||
@@ -12,6 +12,7 @@ os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().pare
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.repositories import (
|
||||
FeedbackCreate,
|
||||
InvalidMetadataFilterError,
|
||||
RunCreate,
|
||||
RunEventCreate,
|
||||
ThreadMetaCreate,
|
||||
@@ -175,6 +176,77 @@ async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_metadata_filters_are_type_safe(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-true", metadata={"value": True}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-false", metadata={"value": False}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="int-one", metadata={"value": 1}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="null-value", metadata={"value": None}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="missing-value", metadata={"other": "x"}))
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": True})] == ["bool-true"]
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": False})] == ["bool-false"]
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": 1})] == ["int-one"]
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": None})] == ["null-value"]
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_metadata_filters_paginate_after_sql_match(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
for index in range(30):
|
||||
metadata = {"target": "yes"} if index % 3 == 0 else {"target": "no"}
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id=f"thread-{index:02d}", metadata=metadata))
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
first_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=0)
|
||||
second_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=3)
|
||||
last_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=9)
|
||||
|
||||
assert len(first_page) == 3
|
||||
assert len(second_page) == 3
|
||||
assert len(last_page) == 1
|
||||
assert {row.thread_id for row in first_page}.isdisjoint({row.thread_id for row in second_page})
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_metadata_filter_rejects_invalid_entries(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-1", metadata={"env": "prod"}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-2", metadata={"env": "staging"}))
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
partial = await repo.search_threads(metadata={"env": "prod", "bad;key": "ignored"})
|
||||
assert [row.thread_id for row in partial] == ["thread-1"]
|
||||
|
||||
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
|
||||
await repo.search_threads(metadata={"bad;key": "x"})
|
||||
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
|
||||
await repo.search_threads(metadata={"env": ["prod", "staging"]})
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
|
||||
Reference in New Issue
Block a user