fix(persistence): address new Copilot review comments

- feedback.py: validate thread_id/run_id before deleting feedback
- jsonl.py: add path traversal protection with ID validation
- run_repo.py: parse `before` to datetime for PostgreSQL compat
- thread_meta_repo.py: fix pagination when metadata filter is active
- database_config.py: use resolve_path for sqlite_dir consistency

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-06 21:46:54 +08:00
parent 0ecc2f954c
commit 5ead75d289
5 changed files with 46 additions and 9 deletions
+6
View File
@@ -115,6 +115,12 @@ async def delete_feedback(
) -> dict[str, bool]: ) -> dict[str, bool]:
"""Delete a feedback record.""" """Delete a feedback record."""
feedback_repo = get_feedback_repo(request) feedback_repo = get_feedback_repo(request)
# Verify feedback belongs to the specified thread/run before deleting
existing = await feedback_repo.get(feedback_id)
if existing is None:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
if existing.get("thread_id") != thread_id or existing.get("run_id") != run_id:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found in run {run_id}")
deleted = await feedback_repo.delete(feedback_id) deleted = await feedback_repo.delete(feedback_id)
if not deleted: if not deleted:
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found") raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
@@ -65,12 +65,18 @@ class DatabaseConfig(BaseModel):
@property
def checkpointer_sqlite_path(self) -> str:
    """Return the SQLite file path used by the LangGraph checkpointer.

    ``sqlite_dir`` is first passed through ``resolve_path``; the file name
    ``checkpoints.db`` is then joined onto the stringified result.
    """
    # Function-scope import, kept as in the original.
    # NOTE(review): presumably avoids an import cycle — confirm.
    from deerflow.config.paths import resolve_path

    base = resolve_path(self.sqlite_dir)
    return os.path.join(str(base), "checkpoints.db")
@property
def app_sqlite_path(self) -> str:
    """Return the SQLite file path used for application ORM data.

    ``sqlite_dir`` is first passed through ``resolve_path``; the file name
    ``app.db`` is then joined onto the stringified result.
    """
    # Function-scope import, kept as in the original.
    # NOTE(review): presumably avoids an import cycle — confirm.
    from deerflow.config.paths import resolve_path

    base = resolve_path(self.sqlite_dir)
    return os.path.join(str(base), "app.db")
@property @property
def app_sqlalchemy_url(self) -> str: def app_sqlalchemy_url(self) -> str:
@@ -126,8 +126,13 @@ class RunRepository(RunStore):
await session.commit() await session.commit()
async def list_pending(self, *, before=None):
    """Return pending runs created at or before a cutoff, oldest first.

    Args:
        before: Cutoff applied to ``created_at``. ``None`` means "now" in
            UTC; a ``datetime`` is used directly; anything else is parsed
            with ``datetime.fromisoformat``.

    Returns:
        A list of run dicts (built by ``_row_to_dict``) ordered by
        ascending ``created_at``.

    Raises:
        ValueError: If ``before`` is a string that is not valid ISO-8601.
    """
    if before is None:
        before_dt = datetime.now(UTC)
    elif isinstance(before, datetime):
        before_dt = before
    else:
        before_dt = datetime.fromisoformat(before)

    # The default cutoff is timezone-aware UTC. An ISO string without an
    # offset (or a naive datetime) would otherwise bind a naive value, so
    # the comparison parameter would differ in kind between branches.
    # NOTE(review): assumes naive inputs are meant as UTC — confirm.
    if before_dt.tzinfo is None:
        before_dt = before_dt.replace(tzinfo=UTC)

    stmt = (
        select(RunRow)
        .where(RunRow.status == "pending", RunRow.created_at <= before_dt)
        .order_by(RunRow.created_at.asc())
    )
    async with self._sf() as session:
        result = await session.execute(stmt)
        return [self._row_to_dict(r) for r in result.scalars()]
@@ -88,14 +88,22 @@ class ThreadMetaRepository(ThreadMetaStore):
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc()) stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
if status: if status:
stmt = stmt.where(ThreadMetaRow.status == status) stmt = stmt.where(ThreadMetaRow.status == status)
stmt = stmt.limit(limit).offset(offset)
async with self._sf() as session:
result = await session.execute(stmt)
rows = [self._row_to_dict(r) for r in result.scalars()]
if metadata: if metadata:
# When metadata filter is active, fetch a larger window and filter
# in Python. TODO(Phase 2): use JSON DB operators (Postgres @>,
# SQLite json_extract) for server-side filtering.
stmt = stmt.limit(limit * 5 + offset)
async with self._sf() as session:
result = await session.execute(stmt)
rows = [self._row_to_dict(r) for r in result.scalars()]
rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())] rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())]
return rows return rows[offset : offset + limit]
else:
stmt = stmt.limit(limit).offset(offset)
async with self._sf() as session:
result = await session.execute(stmt)
return [self._row_to_dict(r) for r in result.scalars()]
async def update_display_name(self, thread_id: str, display_name: str) -> None: async def update_display_name(self, thread_id: str, display_name: str) -> None:
"""Update the display_name (title) for a thread.""" """Update the display_name (title) for a thread."""
@@ -15,6 +15,7 @@ from __future__ import annotations
import json import json
import logging import logging
import re
from datetime import UTC, datetime from datetime import UTC, datetime
from pathlib import Path from pathlib import Path
@@ -22,16 +23,27 @@ from deerflow.runtime.events.store.base import RunEventStore
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Whitelist for IDs that get embedded in filesystem paths: ASCII letters,
# digits, underscore and dash only. The end anchor is \Z rather than $,
# because $ also matches just before a trailing newline and would let an ID
# ending in a newline character pass validation.
_SAFE_ID_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+\Z")
class JsonlRunEventStore(RunEventStore): class JsonlRunEventStore(RunEventStore):
def __init__(self, base_dir: str | Path | None = None):
    """Set up the store's root directory and the per-thread sequence cache.

    Args:
        base_dir: Root directory for event files. Any falsy value
            (``None`` or an empty string) falls back to ``.deer-flow``.
    """
    root = base_dir or ".deer-flow"
    self._base_dir = Path(root)
    # thread_id -> current max seq
    self._seq_counters: dict[str, int] = {}
@staticmethod
def _validate_id(value: str, label: str) -> str:
    """Validate that an ID is safe for use in filesystem paths.

    Args:
        value: Candidate identifier.
        label: Name of the identifier, used only in the error message.

    Returns:
        ``value`` unchanged when it is valid.

    Raises:
        ValueError: If ``value`` is empty or contains anything other than
            ASCII letters, digits, ``_`` or ``-``.
    """
    # fullmatch() requires the entire string to match, so a trailing
    # newline (which a `$` anchor tolerates under match()) is rejected.
    if not value or not _SAFE_ID_PATTERN.fullmatch(value):
        raise ValueError(f"Invalid {label}: must be alphanumeric/dash/underscore, got {value!r}")
    return value
def _thread_dir(self, thread_id: str) -> Path:
    """Return the ``runs`` directory of a thread under the base directory.

    ``thread_id`` is checked with ``_validate_id`` first; any exception it
    raises propagates to the caller.
    """
    safe_thread_id = self._validate_id(thread_id, "thread_id")
    return self._base_dir.joinpath("threads", safe_thread_id, "runs")
def _run_file(self, thread_id: str, run_id: str) -> Path:
    """Return the path of the ``.jsonl`` event file for one run.

    ``run_id`` is validated here; ``thread_id`` is validated afterwards by
    ``_thread_dir``. Exceptions from either check propagate.
    """
    self._validate_id(run_id, "run_id")
    runs_dir = self._thread_dir(thread_id)
    file_name = f"{run_id}.jsonl"
    return runs_dir / file_name
def _next_seq(self, thread_id: str) -> int: def _next_seq(self, thread_id: str) -> int: