mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
fix(storage): harden sql persistence compatibility
This commit is contained in:
@@ -11,7 +11,7 @@ from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def _validate_mysql_driver(db_url: str) -> str:
|
||||
def _validate_mysql_driver(db_url: URL) -> str:
|
||||
url = make_url(db_url)
|
||||
driver = url.get_driver_name()
|
||||
|
||||
@@ -23,6 +23,10 @@ def _validate_mysql_driver(db_url: str) -> str:
|
||||
return driver
|
||||
|
||||
|
||||
def _checkpoint_conn_string(db_url: URL) -> str:
    """Render *db_url* as a plain connection string for the checkpoint saver.

    The password is rendered in clear text (``hide_password=False``) instead
    of being masked, so the resulting string is usable for connecting.
    """
    conn_string = db_url.render_as_string(hide_password=False)
    return conn_string
|
||||
|
||||
|
||||
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||
_validate_mysql_driver(db_url)
|
||||
|
||||
@@ -46,7 +50,7 @@ async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size:
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AIOMySQLSaver.from_conn_string(db_url)
|
||||
saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(db_url))
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
|
||||
@@ -10,6 +10,10 @@ from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def _checkpoint_conn_string(db_url: URL) -> str:
    """Render *db_url* as a connection string for the checkpoint saver.

    The driver name is replaced with plain ``postgresql`` and the password is
    rendered in clear text (``hide_password=False``).
    """
    plain_url = db_url.set(drivername="postgresql")
    return plain_url.render_as_string(hide_password=False)
|
||||
|
||||
|
||||
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
@@ -31,7 +35,7 @@ async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_si
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AsyncPostgresSaver.from_conn_string(db_url)
|
||||
saver_cm = AsyncPostgresSaver.from_conn_string(_checkpoint_conn_string(db_url))
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
|
||||
@@ -0,0 +1,192 @@
|
||||
"""Dialect-aware JSON value matching for storage SQLAlchemy repositories."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import BigInteger, Float, String, bindparam
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql.compiler import SQLCompiler
|
||||
from sqlalchemy.sql.expression import ColumnElement
|
||||
from sqlalchemy.sql.visitors import InternalTraversal
|
||||
from sqlalchemy.types import Boolean, TypeEngine
|
||||
|
||||
# Characters allowed in a metadata filter key: ASCII letters, digits, "_" and "-".
_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
# Python types accepted as metadata filter values.
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)

# Bounds of a signed 64-bit integer; int filter values outside them are rejected.
_INT64_MIN = -(2**63)
_INT64_MAX = 2**63 - 1


def validate_metadata_filter_key(key: object) -> bool:
    """Return True when *key* is safe for JSON metadata filter SQL paths.

    A key is accepted only if it is a non-empty ``str`` made up entirely of
    characters allowed by ``_KEY_CHARSET_RE``. Any other type returns False.
    """
    # fullmatch(), not match(): the pattern's trailing "$" also matches just
    # before a final newline, so match() would accept a key such as "name\n".
    return isinstance(key, str) and _KEY_CHARSET_RE.fullmatch(key) is not None
|
||||
|
||||
|
||||
def validate_metadata_filter_value(value: object) -> bool:
    """Return True when *value* can be compiled into a portable JSON predicate.

    The value must be an instance of one of ``ALLOWED_FILTER_VALUE_TYPES``.
    Non-bool integers must additionally fit in the signed 64-bit range
    ``[_INT64_MIN, _INT64_MAX]``.
    """
    if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
        return False
    # bool is a subclass of int, so it is excluded explicitly from the range check.
    is_plain_int = isinstance(value, int) and not isinstance(value, bool)
    return (_INT64_MIN <= value <= _INT64_MAX) if is_plain_int else True
|
||||
|
||||
|
||||
class JsonMatch(ColumnElement[bool]):
    """Dialect-portable ``column[key] == value`` for JSON columns.

    The element only stores the column, key and value; the SQL text is
    produced by the per-dialect ``@compiles`` hooks registered for this class.

    Raises:
        ValueError: if *key* fails ``validate_metadata_filter_key``.
        TypeError: if *value* fails ``validate_metadata_filter_value`` (wrong
            type, or an int outside the signed 64-bit range).
    """

    inherit_cache = True
    type = Boolean()
    _is_implicitly_boolean = True

    # Attributes visited by SQLAlchemy's traversal machinery for this element.
    _traverse_internals = [
        ("column", InternalTraversal.dp_clauseelement),
        ("key", InternalTraversal.dp_string),
        ("value", InternalTraversal.dp_plain_obj),
        ("value_type", InternalTraversal.dp_string),
    ]

    def __init__(self, column: ColumnElement[Any], key: str, value: object) -> None:
        if not validate_metadata_filter_key(key):
            raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")

        if not validate_metadata_filter_value(value):
            # A non-bool int can only fail validation by being out of range,
            # so it gets a more specific message than a wrong-type value.
            rejected_int = isinstance(value, int) and not isinstance(value, bool)
            if rejected_int:
                message = f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}"
            else:
                message = f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}"
            raise TypeError(message)

        self.column = column
        self.key = key
        self.value = value
        # NOTE(review): presumably recorded so values that compare equal but
        # differ in type (e.g. True vs 1) are kept distinct — confirm.
        self.value_type = type(value).__qualname__
        super().__init__()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _Dialect:
    """Per-database vocabulary used by ``_build_clause`` to render JSON predicates."""

    # Type name the JSON type expression is compared against for a ``None`` filter value.
    null_type: str
    # Type names accepted when matching a float filter value.
    num_types: tuple[str, ...]
    # SQL type the extracted value is CAST to for the float comparison.
    num_cast: str
    # Type names accepted when matching an int filter value.
    int_types: tuple[str, ...]
    # SQL type the extracted value is CAST to for the int comparison.
    int_cast: str
    # Optional SQL regex literal; when set, the extracted text must match it
    # (via ``~``) before the int CAST is evaluated.
    int_guard: str | None
    # Type name expected for string filter values.
    string_type: str
    # Type name for booleans, or None when the type expression itself is
    # compared directly against true_value / false_value.
    bool_type: str | None
    # Text compared against for ``True`` and ``False`` filter values.
    true_value: str
    false_value: str
|
||||
|
||||
|
||||
# Vocabulary for SQLite; the type names are compared against the
# ``json_type(...)`` expression built in ``_compile_sqlite``.
_SQLITE = _Dialect(
    null_type="null",
    num_types=("integer", "real"),
    num_cast="REAL",
    int_types=("integer",),
    int_cast="INTEGER",
    int_guard=None,
    string_type="text",
    # No separate boolean type name: the type expression is compared directly
    # against true_value / false_value.
    bool_type=None,
    true_value="true",
    false_value="false",
)
|
||||
|
||||
# Vocabulary for PostgreSQL; the type names are compared against the
# ``json_typeof(...)`` expression built in ``_compile_postgres``.
_POSTGRES = _Dialect(
    null_type="null",
    num_types=("number",),
    num_cast="DOUBLE PRECISION",
    int_types=("number",),
    int_cast="BIGINT",
    # Ints and floats share the "number" type name here, so only text that is
    # an optional minus sign followed by digits is cast to BIGINT.
    int_guard="'^-?[0-9]+$'",
    string_type="string",
    bool_type="boolean",
    true_value="true",
    false_value="false",
)
|
||||
|
||||
# Vocabulary for MySQL; the type names are compared against the
# ``JSON_TYPE(JSON_EXTRACT(...))`` expression built in ``_compile_mysql``.
_MYSQL = _Dialect(
    null_type="NULL",
    num_types=("INTEGER", "DOUBLE", "DECIMAL"),
    num_cast="DOUBLE",
    int_types=("INTEGER",),
    int_cast="SIGNED",
    int_guard=None,
    string_type="STRING",
    bool_type="BOOLEAN",
    true_value="true",
    false_value="false",
)
|
||||
|
||||
|
||||
def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
    """Render *value* as an anonymous bound parameter of type *sa_type*.

    Returns the SQL text that *compiler* produces for the parameter, so the
    value is passed as a bind parameter rather than interpolated into the SQL.
    """
    return compiler.process(bindparam(None, value, type_=sa_type), **kw)
|
||||
|
||||
|
||||
def _type_check(typeof: str, types: tuple[str, ...]) -> str:
    """Build a SQL test that the expression *typeof* is one of *types*.

    A single type name yields ``<typeof> = '<name>'``; any other number of
    names yields ``<typeof> IN ('<a>', '<b>', ...)``.
    """
    if len(types) != 1:
        options = ", ".join(f"'{name}'" for name in types)
        return f"{typeof} IN ({options})"
    (only_type,) = types
    return f"{typeof} = '{only_type}'"
|
||||
|
||||
|
||||
def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
    """Render the SQL predicate matching a JSON member against *value*.

    Args:
        compiler: SQL compiler used to render bound parameters.
        typeof: SQL expression yielding the JSON type name of the member.
        extract: SQL expression yielding the member's extracted value.
        value: the filter value (None, bool, int, float or str).
        dialect: type names and cast targets for the target database.
        **kw: compiler keyword arguments forwarded when rendering parameters.

    Returns:
        SQL text for the predicate. Int, float and str values are passed as
        bound parameters.
    """
    if value is None:
        return f"{typeof} = '{dialect.null_type}'"

    # bool is tested before int because bool is a subclass of int.
    if isinstance(value, bool):
        literal = dialect.true_value if value else dialect.false_value
        if dialect.bool_type is not None:
            return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{literal}')"
        return f"{typeof} = '{literal}'"

    if isinstance(value, int):
        param = _bind(compiler, value, BigInteger(), **kw)
        is_int = _type_check(typeof, dialect.int_types)
        as_int = f"CAST({extract} AS {dialect.int_cast})"
        if not dialect.int_guard:
            return f"({is_int} AND {as_int} = {param})"
        # With a guard, the CAST sits inside a CASE so it is only reached once
        # the type check and the regex guard have both passed.
        return f"(CASE WHEN {is_int} AND {extract} ~ {dialect.int_guard} THEN {as_int} END = {param})"

    if isinstance(value, float):
        param = _bind(compiler, value, Float(), **kw)
        is_num = _type_check(typeof, dialect.num_types)
        return f"({is_num} AND CAST({extract} AS {dialect.num_cast}) = {param})"

    param = _bind(compiler, str(value), String(), **kw)
    return f"({typeof} = '{dialect.string_type}' AND {extract} = {param})"
|
||||
|
||||
|
||||
@compiles(JsonMatch, "sqlite")
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    """Compile ``JsonMatch`` for SQLite using ``json_type`` / ``json_extract``.

    Raises:
        ValueError: if the element's key no longer passes validation.
    """
    key = element.key
    # Re-validated at compile time because the key is interpolated into the SQL text.
    if not validate_metadata_filter_key(key):
        raise ValueError(f"Key escaped validation: {key!r}")

    column_sql = compiler.process(element.column, **kw)
    json_path = f'$."{key}"'
    return _build_clause(
        compiler,
        f"json_type({column_sql}, '{json_path}')",
        f"json_extract({column_sql}, '{json_path}')",
        element.value,
        _SQLITE,
        **kw,
    )
|
||||
|
||||
|
||||
@compiles(JsonMatch, "postgresql")
def _compile_postgres(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    """Compile ``JsonMatch`` for PostgreSQL using ``json_typeof`` and ``->`` / ``->>``.

    Raises:
        ValueError: if the element's key no longer passes validation.
    """
    key = element.key
    # Re-validated at compile time because the key is interpolated into the SQL text.
    if not validate_metadata_filter_key(key):
        raise ValueError(f"Key escaped validation: {key!r}")

    column_sql = compiler.process(element.column, **kw)
    return _build_clause(
        compiler,
        f"json_typeof({column_sql} -> '{key}')",
        f"({column_sql} ->> '{key}')",
        element.value,
        _POSTGRES,
        **kw,
    )
|
||||
|
||||
|
||||
@compiles(JsonMatch, "mysql")
def _compile_mysql(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    """Compile ``JsonMatch`` for MySQL using ``JSON_TYPE`` / ``JSON_EXTRACT``.

    Raises:
        ValueError: if the element's key no longer passes validation.
    """
    key = element.key
    # Re-validated at compile time because the key is interpolated into the SQL text.
    if not validate_metadata_filter_key(key):
        raise ValueError(f"Key escaped validation: {key!r}")

    column_sql = compiler.process(element.column, **kw)
    json_path = f'$."{key}"'
    return _build_clause(
        compiler,
        f"JSON_TYPE(JSON_EXTRACT({column_sql}, '{json_path}'))",
        f"JSON_UNQUOTE(JSON_EXTRACT({column_sql}, '{json_path}'))",
        element.value,
        _MYSQL,
        **kw,
    )
|
||||
|
||||
|
||||
@compiles(JsonMatch)
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    """Fallback compiler for dialects without a dedicated ``JsonMatch`` hook.

    Raises:
        NotImplementedError: always, naming the unsupported dialect.
    """
    dialect_name = compiler.dialect.name
    raise NotImplementedError(f"JsonMatch supports sqlite, postgresql, and mysql; got dialect: {dialect_name}")
|
||||
|
||||
|
||||
def json_match(column: ColumnElement[Any], key: str, value: object) -> JsonMatch:
    """Build a ``JsonMatch`` predicate for ``column[key] == value``.

    Raises whatever ``JsonMatch`` raises for an invalid *key* (``ValueError``)
    or an unsupported *value* (``TypeError``).
    """
    predicate = JsonMatch(column, key, value)
    return predicate
|
||||
@@ -3,6 +3,7 @@ from store.repositories.contracts import (
|
||||
FeedbackAggregate,
|
||||
FeedbackCreate,
|
||||
FeedbackRepositoryProtocol,
|
||||
InvalidMetadataFilterError,
|
||||
Run,
|
||||
RunCreate,
|
||||
RunEvent,
|
||||
@@ -30,6 +31,7 @@ __all__ = [
|
||||
"FeedbackAggregate",
|
||||
"FeedbackCreate",
|
||||
"FeedbackRepositoryProtocol",
|
||||
"InvalidMetadataFilterError",
|
||||
"Run",
|
||||
"RunCreate",
|
||||
"RunEvent",
|
||||
|
||||
@@ -15,6 +15,7 @@ from store.repositories.contracts.run_event import (
|
||||
RunEventRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.contracts.thread_meta import (
|
||||
InvalidMetadataFilterError,
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
@@ -37,6 +38,7 @@ __all__ = [
|
||||
"RunEventCreate",
|
||||
"RunEventRepositoryProtocol",
|
||||
"RunRepositoryProtocol",
|
||||
"InvalidMetadataFilterError",
|
||||
"ThreadMeta",
|
||||
"ThreadMetaCreate",
|
||||
"ThreadMetaRepositoryProtocol",
|
||||
|
||||
@@ -6,6 +6,10 @@ from typing import Any, Protocol
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class InvalidMetadataFilterError(ValueError):
    """Raised when all client-supplied metadata filters are rejected.

    Subclasses ``ValueError``, so handlers that catch ``ValueError`` also
    catch this error.
    """
|
||||
|
||||
|
||||
class ThreadMetaCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
|
||||
@@ -153,9 +153,10 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
completed = RunModel.status.in_(("success", "error"))
|
||||
model_expr = func.coalesce(RunModel.model_name, "unknown")
|
||||
stmt = (
|
||||
select(
|
||||
func.coalesce(RunModel.model_name, "unknown").label("model"),
|
||||
model_expr.label("model"),
|
||||
func.count().label("runs"),
|
||||
func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
|
||||
func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
|
||||
@@ -165,7 +166,7 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
|
||||
)
|
||||
.where(RunModel.thread_id == thread_id, completed)
|
||||
.group_by(func.coalesce(RunModel.model_name, "unknown"))
|
||||
.group_by(model_expr)
|
||||
)
|
||||
|
||||
rows = (await self._session.execute(stmt)).all()
|
||||
|
||||
@@ -56,8 +56,10 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
seq_by_thread: dict[str, int] = {}
|
||||
for thread_id in thread_ids:
|
||||
max_seq = await self._session.scalar(
|
||||
select(func.max(RunEventModel.seq))
|
||||
select(RunEventModel.seq)
|
||||
.where(RunEventModel.thread_id == thread_id)
|
||||
.order_by(RunEventModel.seq.desc())
|
||||
.limit(1)
|
||||
.with_for_update()
|
||||
)
|
||||
seq_by_thread[thread_id] = max_seq or 0
|
||||
|
||||
@@ -1,13 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.thread_meta import ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol
|
||||
from store.persistence.json_compat import json_match
|
||||
from store.repositories.contracts.thread_meta import (
|
||||
InvalidMetadataFilterError,
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
|
||||
return ThreadMeta(
|
||||
@@ -87,10 +96,18 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
|
||||
if assistant_id is not None:
|
||||
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
|
||||
if metadata:
|
||||
applied = 0
|
||||
for key, value in metadata.items():
|
||||
stmt = stmt.where(ThreadMetaModel.meta[key].as_string() == str(value))
|
||||
try:
|
||||
stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
|
||||
applied += 1
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
|
||||
if applied == 0:
|
||||
rejected_keys = ", ".join(sorted(str(key) for key in metadata))
|
||||
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
|
||||
|
||||
stmt = stmt.order_by(ThreadMetaModel.created_time.desc())
|
||||
stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
|
||||
Reference in New Issue
Block a user