From d3066a17460bbede81c92468b71dc86f245a1d99 Mon Sep 17 00:00:00 2001 From: rayhpeng Date: Wed, 13 May 2026 11:26:25 +0800 Subject: [PATCH] fix(storage): harden sql persistence compatibility --- .../store/persistence/drivers/mysql.py | 8 +- .../store/persistence/drivers/postgres.py | 6 +- .../storage/store/persistence/json_compat.py | 192 ++++++++++++++++++ .../storage/store/repositories/__init__.py | 2 + .../store/repositories/contracts/__init__.py | 2 + .../repositories/contracts/thread_meta.py | 4 + .../storage/store/repositories/db/run.py | 5 +- .../store/repositories/db/run_event.py | 4 +- .../store/repositories/db/thread_meta.py | 23 ++- backend/tests/test_storage_json_compat.py | 89 ++++++++ backend/tests/test_storage_repositories.py | 72 +++++++ 11 files changed, 398 insertions(+), 9 deletions(-) create mode 100644 backend/packages/storage/store/persistence/json_compat.py create mode 100644 backend/tests/test_storage_json_compat.py diff --git a/backend/packages/storage/store/persistence/drivers/mysql.py b/backend/packages/storage/store/persistence/drivers/mysql.py index f68b117b7..c63d10155 100644 --- a/backend/packages/storage/store/persistence/drivers/mysql.py +++ b/backend/packages/storage/store/persistence/drivers/mysql.py @@ -11,7 +11,7 @@ from store.persistence.shared import close_in_order from store.persistence.types import AppPersistence -def _validate_mysql_driver(db_url: str) -> str: +def _validate_mysql_driver(db_url: URL) -> str: url = make_url(db_url) driver = url.get_driver_name() @@ -23,6 +23,10 @@ def _validate_mysql_driver(db_url: str) -> str: return driver +def _checkpoint_conn_string(db_url: URL) -> str: + return db_url.render_as_string(hide_password=False) + + async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence: _validate_mysql_driver(db_url) @@ -46,7 +50,7 @@ async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: autoflush=False, ) - saver_cm = AIOMySQLSaver.from_conn_string(db_url) + saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(db_url)) checkpointer = await saver_cm.__aenter__() async def setup() -> None: diff --git a/backend/packages/storage/store/persistence/drivers/postgres.py b/backend/packages/storage/store/persistence/drivers/postgres.py index 99b98a4ff..57f9d5cb7 100644 --- a/backend/packages/storage/store/persistence/drivers/postgres.py +++ b/backend/packages/storage/store/persistence/drivers/postgres.py @@ -10,6 +10,10 @@ from store.persistence.shared import close_in_order from store.persistence.types import AppPersistence +def _checkpoint_conn_string(db_url: URL) -> str: + return db_url.set(drivername="postgresql").render_as_string(hide_password=False) + + async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence: from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver @@ -31,7 +35,7 @@ async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_si autoflush=False, ) - saver_cm = AsyncPostgresSaver.from_conn_string(db_url) + saver_cm = AsyncPostgresSaver.from_conn_string(_checkpoint_conn_string(db_url)) checkpointer = await saver_cm.__aenter__() async def setup() -> None: diff --git a/backend/packages/storage/store/persistence/json_compat.py b/backend/packages/storage/store/persistence/json_compat.py new file mode 100644 index 000000000..acd85fd34 --- /dev/null +++ b/backend/packages/storage/store/persistence/json_compat.py @@ -0,0 +1,192 @@ +"""Dialect-aware JSON value matching for storage SQLAlchemy repositories.""" + +from __future__ import annotations + +import re +from dataclasses import dataclass +from typing import Any + +from sqlalchemy import BigInteger, Float, String, bindparam +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.sql.compiler import SQLCompiler +from sqlalchemy.sql.expression import ColumnElement +from sqlalchemy.sql.visitors import InternalTraversal +from sqlalchemy.types import Boolean, TypeEngine + +_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$") +ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str) + +_INT64_MIN = -(2**63) +_INT64_MAX = 2**63 - 1 + + +def validate_metadata_filter_key(key: object) -> bool: + """Return True when *key* is safe for JSON metadata filter SQL paths.""" + return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key)) + + +def validate_metadata_filter_value(value: object) -> bool: + """Return True when *value* can be compiled into a portable JSON predicate.""" + if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES): + return False + if isinstance(value, int) and not isinstance(value, bool): + return _INT64_MIN <= value <= _INT64_MAX + return True + + +class JsonMatch(ColumnElement[bool]): + """Dialect-portable ``column[key] == value`` for JSON columns.""" + + inherit_cache = True + type = Boolean() + _is_implicitly_boolean = True + + _traverse_internals = [ + ("column", InternalTraversal.dp_clauseelement), + ("key", InternalTraversal.dp_string), + ("value", InternalTraversal.dp_plain_obj), + ("value_type", InternalTraversal.dp_string), + ] + + def __init__(self, column: ColumnElement[Any], key: str, value: object) -> None: + if not validate_metadata_filter_key(key): + raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}") + if not validate_metadata_filter_value(value): + if isinstance(value, int) and not isinstance(value, bool): + raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}") + raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}") + self.column = column + self.key = key + self.value = value + self.value_type = type(value).__qualname__ + super().__init__() + + +@dataclass(frozen=True) +class _Dialect: + null_type: str + num_types: tuple[str, ...] + num_cast: str + int_types: tuple[str, ...] + int_cast: str + int_guard: str | None + string_type: str + bool_type: str | None + true_value: str + false_value: str + + +_SQLITE = _Dialect( + null_type="null", + num_types=("integer", "real"), + num_cast="REAL", + int_types=("integer",), + int_cast="INTEGER", + int_guard=None, + string_type="text", + bool_type=None, + true_value="true", + false_value="false", +) + +_POSTGRES = _Dialect( + null_type="null", + num_types=("number",), + num_cast="DOUBLE PRECISION", + int_types=("number",), + int_cast="BIGINT", + int_guard="'^-?[0-9]+$'", + string_type="string", + bool_type="boolean", + true_value="true", + false_value="false", +) + +_MYSQL = _Dialect( + null_type="NULL", + num_types=("INTEGER", "DOUBLE", "DECIMAL"), + num_cast="DOUBLE", + int_types=("INTEGER",), + int_cast="SIGNED", + int_guard=None, + string_type="STRING", + bool_type="BOOLEAN", + true_value="true", + false_value="false", +) + + +def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str: + param = bindparam(None, value, type_=sa_type) + return compiler.process(param, **kw) + + +def _type_check(typeof: str, types: tuple[str, ...]) -> str: + if len(types) == 1: + return f"{typeof} = '{types[0]}'" + quoted = ", ".join(f"'{type_name}'" for type_name in types) + return f"{typeof} IN ({quoted})" + + +def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str: + if value is None: + return f"{typeof} = '{dialect.null_type}'" + if isinstance(value, bool): + bool_str = dialect.true_value if value else dialect.false_value + if dialect.bool_type is None: + return f"{typeof} = '{bool_str}'" + return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')" + if isinstance(value, int): + bp = _bind(compiler, value, BigInteger(), **kw) + if dialect.int_guard: + return ( + f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} " + f"THEN CAST({extract} AS {dialect.int_cast}) END = {bp})" + ) + return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})" + if isinstance(value, float): + bp = _bind(compiler, value, Float(), **kw) + return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})" + bp = _bind(compiler, str(value), String(), **kw) + return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})" + + +@compiles(JsonMatch, "sqlite") +def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + path = f'$."{element.key}"' + typeof = f"json_type({col}, '{path}')" + extract = f"json_extract({col}, '{path}')" + return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw) + + +@compiles(JsonMatch, "postgresql") +def _compile_postgres(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + typeof = f"json_typeof({col} -> '{element.key}')" + extract = f"({col} ->> '{element.key}')" + return _build_clause(compiler, typeof, extract, element.value, _POSTGRES, **kw) + + +@compiles(JsonMatch, "mysql") +def _compile_mysql(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + if not validate_metadata_filter_key(element.key): + raise ValueError(f"Key escaped validation: {element.key!r}") + col = compiler.process(element.column, **kw) + path = f'$."{element.key}"' + typeof = f"JSON_TYPE(JSON_EXTRACT({col}, '{path}'))" + extract = f"JSON_UNQUOTE(JSON_EXTRACT({col}, '{path}'))" + return _build_clause(compiler, typeof, extract, element.value, _MYSQL, **kw) + + +@compiles(JsonMatch) +def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str: + raise NotImplementedError(f"JsonMatch supports sqlite, postgresql, and mysql; got dialect: {compiler.dialect.name}") + + +def json_match(column: ColumnElement[Any], key: str, value: object) -> JsonMatch: + return JsonMatch(column, key, value) diff --git a/backend/packages/storage/store/repositories/__init__.py b/backend/packages/storage/store/repositories/__init__.py index 4b3f078e7..a7bc267fd 100644 --- a/backend/packages/storage/store/repositories/__init__.py +++ b/backend/packages/storage/store/repositories/__init__.py @@ -3,6 +3,7 @@ from store.repositories.contracts import ( FeedbackAggregate, FeedbackCreate, FeedbackRepositoryProtocol, + InvalidMetadataFilterError, Run, RunCreate, RunEvent, @@ -30,6 +31,7 @@ __all__ = [ "FeedbackAggregate", "FeedbackCreate", "FeedbackRepositoryProtocol", + "InvalidMetadataFilterError", "Run", "RunCreate", "RunEvent", diff --git a/backend/packages/storage/store/repositories/contracts/__init__.py b/backend/packages/storage/store/repositories/contracts/__init__.py index 4876c4a6b..8492cde3e 100644 --- a/backend/packages/storage/store/repositories/contracts/__init__.py +++ b/backend/packages/storage/store/repositories/contracts/__init__.py @@ -15,6 +15,7 @@ from store.repositories.contracts.run_event import ( RunEventRepositoryProtocol, ) from store.repositories.contracts.thread_meta import ( + InvalidMetadataFilterError, ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol, @@ -37,6 +38,7 @@ __all__ = [ "RunEventCreate", "RunEventRepositoryProtocol", "RunRepositoryProtocol", + "InvalidMetadataFilterError", "ThreadMeta", "ThreadMetaCreate", "ThreadMetaRepositoryProtocol", diff --git a/backend/packages/storage/store/repositories/contracts/thread_meta.py b/backend/packages/storage/store/repositories/contracts/thread_meta.py index de2d82b48..8222ca0e4 100644 --- a/backend/packages/storage/store/repositories/contracts/thread_meta.py +++ b/backend/packages/storage/store/repositories/contracts/thread_meta.py @@ -6,6 +6,10 @@ from typing import Any, Protocol from pydantic import BaseModel, ConfigDict, Field +class InvalidMetadataFilterError(ValueError): + """Raised when all client-supplied metadata filters are rejected.""" + + class ThreadMetaCreate(BaseModel): model_config = ConfigDict(extra="forbid") diff --git a/backend/packages/storage/store/repositories/db/run.py b/backend/packages/storage/store/repositories/db/run.py index 93e6e1d95..146ec1e46 100644 --- a/backend/packages/storage/store/repositories/db/run.py +++ b/backend/packages/storage/store/repositories/db/run.py @@ -153,9 +153,10 @@ class DbRunRepository(RunRepositoryProtocol): async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]: completed = RunModel.status.in_(("success", "error")) + model_expr = func.coalesce(RunModel.model_name, "unknown") stmt = ( select( - func.coalesce(RunModel.model_name, "unknown").label("model"), + model_expr.label("model"), func.count().label("runs"), func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"), func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"), @@ -165,7 +166,7 @@ class DbRunRepository(RunRepositoryProtocol): func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"), ) .where(RunModel.thread_id == thread_id, completed) - .group_by(func.coalesce(RunModel.model_name, "unknown")) + .group_by(model_expr) ) rows = (await self._session.execute(stmt)).all() diff --git a/backend/packages/storage/store/repositories/db/run_event.py b/backend/packages/storage/store/repositories/db/run_event.py index f18d17b77..9d22985c0 100644 --- a/backend/packages/storage/store/repositories/db/run_event.py +++ b/backend/packages/storage/store/repositories/db/run_event.py @@ -56,8 +56,10 @@ class DbRunEventRepository(RunEventRepositoryProtocol): seq_by_thread: dict[str, int] = {} for thread_id in thread_ids: max_seq = await self._session.scalar( - select(func.max(RunEventModel.seq)) + select(RunEventModel.seq) .where(RunEventModel.thread_id == thread_id) + .order_by(RunEventModel.seq.desc()) + .limit(1) .with_for_update() ) seq_by_thread[thread_id] = max_seq or 0 diff --git a/backend/packages/storage/store/repositories/db/thread_meta.py b/backend/packages/storage/store/repositories/db/thread_meta.py index f9fcf74d4..af06b2da9 100644 --- a/backend/packages/storage/store/repositories/db/thread_meta.py +++ b/backend/packages/storage/store/repositories/db/thread_meta.py @@ -1,13 +1,22 @@ from __future__ import annotations +import logging from typing import Any from sqlalchemy import delete, select, update from sqlalchemy.ext.asyncio import AsyncSession -from store.repositories.contracts.thread_meta import ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol +from store.persistence.json_compat import json_match +from store.repositories.contracts.thread_meta import ( + InvalidMetadataFilterError, + ThreadMeta, + ThreadMetaCreate, + ThreadMetaRepositoryProtocol, +) from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel +logger = logging.getLogger(__name__) + def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta: return ThreadMeta( @@ -87,10 +96,18 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol): if assistant_id is not None: stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id) if metadata: + applied = 0 for key, value in metadata.items(): - stmt = stmt.where(ThreadMetaModel.meta[key].as_string() == str(value)) + try: + stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value)) + applied += 1 + except (ValueError, TypeError) as exc: + logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc) + if applied == 0: + rejected_keys = ", ".join(sorted(str(key) for key in metadata)) + raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}") - stmt = stmt.order_by(ThreadMetaModel.created_time.desc()) + stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc()) stmt = stmt.limit(limit).offset(offset) result = await self._session.execute(stmt) diff --git a/backend/tests/test_storage_json_compat.py b/backend/tests/test_storage_json_compat.py new file mode 100644 index 000000000..1ca1645ee --- /dev/null +++ b/backend/tests/test_storage_json_compat.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import os +from pathlib import Path + +import pytest +from sqlalchemy import Column, MetaData, String, Table +from sqlalchemy.dialects import mysql, postgresql +from sqlalchemy.types import JSON + +os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml")) + +from store.persistence.json_compat import json_match + + +def _table(): + metadata = MetaData() + return Table("t", metadata, Column("data", JSON), Column("id", String)) + + +def test_storage_json_match_compiles_sqlite() -> None: + from sqlalchemy import create_engine + + table = _table() + dialect = create_engine("sqlite://").dialect + + assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ( + "json_type(t.data, '$.\"k\"') = 'null'" + ) + assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ( + "json_type(t.data, '$.\"k\"') = 'true'" + ) + + int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "= 'integer'" in int_sql + assert "CAST" in int_sql + + float_sql = str(json_match(table.c.data, "k", 3.14).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "IN ('integer', 'real')" in float_sql + assert "REAL" in float_sql + + +def test_storage_json_match_compiles_postgres() -> None: + table = _table() + dialect = postgresql.dialect() + + assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ( + "json_typeof(t.data -> 'k') = 'null'" + ) + assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ( + "(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')" + ) + + int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "CASE WHEN" in int_sql + assert "BIGINT" in int_sql + assert "'^-?[0-9]+$'" in int_sql + + +def test_storage_json_match_compiles_mysql() -> None: + table = _table() + dialect = mysql.dialect() + + null_sql = str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert null_sql == "JSON_TYPE(JSON_EXTRACT(t.data, '$.\"k\"')) = 'NULL'" + + bool_sql = str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "JSON_TYPE(JSON_EXTRACT" in bool_sql + assert "= 'BOOLEAN'" in bool_sql + assert "= 'true'" in bool_sql + + int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) + assert "= 'INTEGER'" in int_sql + assert "SIGNED" in int_sql + + +def test_storage_json_match_rejects_unsafe_keys_and_values() -> None: + table = _table() + + for bad_key in ["a.b", "bad;key", "with space", "", 42, None]: + with pytest.raises(ValueError, match="JsonMatch key must match"): + json_match(table.c.data, bad_key, "x") # type: ignore[arg-type] + + for bad_value in [[], {}, object()]: + with pytest.raises(TypeError, match="JsonMatch value must be"): + json_match(table.c.data, "k", bad_value) + + with pytest.raises(TypeError, match="out of signed 64-bit range"): + json_match(table.c.data, "k", 2**63) diff --git a/backend/tests/test_storage_repositories.py b/backend/tests/test_storage_repositories.py index d04563e55..0e66d934c 100644 --- a/backend/tests/test_storage_repositories.py +++ b/backend/tests/test_storage_repositories.py @@ -12,6 +12,7 @@ os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().pare from store.persistence import create_persistence_from_database_config from store.repositories import ( FeedbackCreate, + InvalidMetadataFilterError, RunCreate, RunEventCreate, ThreadMetaCreate, @@ -175,6 +176,77 @@ async def test_storage_thread_meta_repository_search_update_delete(tmp_path): await persistence.aclose() +@pytest.mark.anyio +async def test_storage_thread_meta_metadata_filters_are_type_safe(tmp_path): + persistence = await _make_persistence(tmp_path) + try: + async with persistence.session_factory() as session: + repo = build_thread_meta_repository(session) + await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-true", metadata={"value": True})) + await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-false", metadata={"value": False})) + await repo.create_thread_meta(ThreadMetaCreate(thread_id="int-one", metadata={"value": 1})) + await repo.create_thread_meta(ThreadMetaCreate(thread_id="null-value", metadata={"value": None})) + await repo.create_thread_meta(ThreadMetaCreate(thread_id="missing-value", metadata={"other": "x"})) + await session.commit() + + async with persistence.session_factory() as session: + repo = build_thread_meta_repository(session) + assert [row.thread_id for row in await repo.search_threads(metadata={"value": True})] == ["bool-true"] + assert [row.thread_id for row in await repo.search_threads(metadata={"value": False})] == ["bool-false"] + assert [row.thread_id for row in await repo.search_threads(metadata={"value": 1})] == ["int-one"] + assert [row.thread_id for row in await repo.search_threads(metadata={"value": None})] == ["null-value"] + finally: + await persistence.aclose() + + +@pytest.mark.anyio +async def test_storage_thread_meta_metadata_filters_paginate_after_sql_match(tmp_path): + persistence = await _make_persistence(tmp_path) + try: + async with persistence.session_factory() as session: + repo = build_thread_meta_repository(session) + for index in range(30): + metadata = {"target": "yes"} if index % 3 == 0 else {"target": "no"} + await repo.create_thread_meta(ThreadMetaCreate(thread_id=f"thread-{index:02d}", metadata=metadata)) + await session.commit() + + async with persistence.session_factory() as session: + repo = build_thread_meta_repository(session) + first_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=0) + second_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=3) + last_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=9) + + assert len(first_page) == 3 + assert len(second_page) == 3 + assert len(last_page) == 1 + assert {row.thread_id for row in first_page}.isdisjoint({row.thread_id for row in second_page}) + finally: + await persistence.aclose() + + +@pytest.mark.anyio +async def test_storage_thread_meta_metadata_filter_rejects_invalid_entries(tmp_path): + persistence = await _make_persistence(tmp_path) + try: + async with persistence.session_factory() as session: + repo = build_thread_meta_repository(session) + await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-1", metadata={"env": "prod"})) + await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-2", metadata={"env": "staging"})) + await session.commit() + + async with persistence.session_factory() as session: + repo = build_thread_meta_repository(session) + partial = await repo.search_threads(metadata={"env": "prod", "bad;key": "ignored"}) + assert [row.thread_id for row in partial] == ["thread-1"] + + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search_threads(metadata={"bad;key": "x"}) + with pytest.raises(InvalidMetadataFilterError, match="rejected"): + await repo.search_threads(metadata={"env": ["prod", "staging"]}) + finally: + await persistence.aclose() + + @pytest.mark.anyio async def test_storage_feedback_repository_lists_and_deletes(tmp_path): persistence = await _make_persistence(tmp_path)