mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
fix(storage): harden sql persistence compatibility
This commit is contained in:
@@ -0,0 +1,89 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, MetaData, String, Table
|
||||
from sqlalchemy.dialects import mysql, postgresql
|
||||
from sqlalchemy.types import JSON
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from store.persistence.json_compat import json_match
|
||||
|
||||
|
||||
def _table():
|
||||
metadata = MetaData()
|
||||
return Table("t", metadata, Column("data", JSON), Column("id", String))
|
||||
|
||||
|
||||
def test_storage_json_match_compiles_sqlite() -> None:
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
table = _table()
|
||||
dialect = create_engine("sqlite://").dialect
|
||||
|
||||
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"json_type(t.data, '$.\"k\"') = 'null'"
|
||||
)
|
||||
assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"json_type(t.data, '$.\"k\"') = 'true'"
|
||||
)
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "= 'integer'" in int_sql
|
||||
assert "CAST" in int_sql
|
||||
|
||||
float_sql = str(json_match(table.c.data, "k", 3.14).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "IN ('integer', 'real')" in float_sql
|
||||
assert "REAL" in float_sql
|
||||
|
||||
|
||||
def test_storage_json_match_compiles_postgres() -> None:
|
||||
table = _table()
|
||||
dialect = postgresql.dialect()
|
||||
|
||||
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"json_typeof(t.data -> 'k') = 'null'"
|
||||
)
|
||||
assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == (
|
||||
"(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')"
|
||||
)
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "CASE WHEN" in int_sql
|
||||
assert "BIGINT" in int_sql
|
||||
assert "'^-?[0-9]+$'" in int_sql
|
||||
|
||||
|
||||
def test_storage_json_match_compiles_mysql() -> None:
|
||||
table = _table()
|
||||
dialect = mysql.dialect()
|
||||
|
||||
null_sql = str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert null_sql == "JSON_TYPE(JSON_EXTRACT(t.data, '$.\"k\"')) = 'NULL'"
|
||||
|
||||
bool_sql = str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "JSON_TYPE(JSON_EXTRACT" in bool_sql
|
||||
assert "= 'BOOLEAN'" in bool_sql
|
||||
assert "= 'true'" in bool_sql
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "= 'INTEGER'" in int_sql
|
||||
assert "SIGNED" in int_sql
|
||||
|
||||
|
||||
def test_storage_json_match_rejects_unsafe_keys_and_values() -> None:
|
||||
table = _table()
|
||||
|
||||
for bad_key in ["a.b", "bad;key", "with space", "", 42, None]:
|
||||
with pytest.raises(ValueError, match="JsonMatch key must match"):
|
||||
json_match(table.c.data, bad_key, "x") # type: ignore[arg-type]
|
||||
|
||||
for bad_value in [[], {}, object()]:
|
||||
with pytest.raises(TypeError, match="JsonMatch value must be"):
|
||||
json_match(table.c.data, "k", bad_value)
|
||||
|
||||
with pytest.raises(TypeError, match="out of signed 64-bit range"):
|
||||
json_match(table.c.data, "k", 2**63)
|
||||
@@ -12,6 +12,7 @@ os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().pare
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.repositories import (
|
||||
FeedbackCreate,
|
||||
InvalidMetadataFilterError,
|
||||
RunCreate,
|
||||
RunEventCreate,
|
||||
ThreadMetaCreate,
|
||||
@@ -175,6 +176,77 @@ async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_metadata_filters_are_type_safe(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-true", metadata={"value": True}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-false", metadata={"value": False}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="int-one", metadata={"value": 1}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="null-value", metadata={"value": None}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="missing-value", metadata={"other": "x"}))
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": True})] == ["bool-true"]
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": False})] == ["bool-false"]
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": 1})] == ["int-one"]
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": None})] == ["null-value"]
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_metadata_filters_paginate_after_sql_match(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
for index in range(30):
|
||||
metadata = {"target": "yes"} if index % 3 == 0 else {"target": "no"}
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id=f"thread-{index:02d}", metadata=metadata))
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
first_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=0)
|
||||
second_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=3)
|
||||
last_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=9)
|
||||
|
||||
assert len(first_page) == 3
|
||||
assert len(second_page) == 3
|
||||
assert len(last_page) == 1
|
||||
assert {row.thread_id for row in first_page}.isdisjoint({row.thread_id for row in second_page})
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_metadata_filter_rejects_invalid_entries(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-1", metadata={"env": "prod"}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-2", metadata={"env": "staging"}))
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
partial = await repo.search_threads(metadata={"env": "prod", "bad;key": "ignored"})
|
||||
assert [row.thread_id for row in partial] == ["thread-1"]
|
||||
|
||||
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
|
||||
await repo.search_threads(metadata={"bad;key": "x"})
|
||||
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
|
||||
await repo.search_threads(metadata={"env": ["prod", "staging"]})
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
|
||||
Reference in New Issue
Block a user