mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 00:45:57 +00:00
fix(persistence): address new Copilot review comments
- feedback.py: validate thread_id/run_id before deleting feedback - jsonl.py: add path traversal protection with ID validation - run_repo.py: parse `before` to datetime for PostgreSQL compat - thread_meta_repo.py: fix pagination when metadata filter is active - database_config.py: use resolve_path for sqlite_dir consistency Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -115,6 +115,12 @@ async def delete_feedback(
|
|||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Delete a feedback record."""
|
"""Delete a feedback record."""
|
||||||
feedback_repo = get_feedback_repo(request)
|
feedback_repo = get_feedback_repo(request)
|
||||||
|
# Verify feedback belongs to the specified thread/run before deleting
|
||||||
|
existing = await feedback_repo.get(feedback_id)
|
||||||
|
if existing is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||||
|
if existing.get("thread_id") != thread_id or existing.get("run_id") != run_id:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found in run {run_id}")
|
||||||
deleted = await feedback_repo.delete(feedback_id)
|
deleted = await feedback_repo.delete(feedback_id)
|
||||||
if not deleted:
|
if not deleted:
|
||||||
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
raise HTTPException(status_code=404, detail=f"Feedback {feedback_id} not found")
|
||||||
|
|||||||
@@ -65,12 +65,18 @@ class DatabaseConfig(BaseModel):
|
|||||||
@property
|
@property
|
||||||
def checkpointer_sqlite_path(self) -> str:
|
def checkpointer_sqlite_path(self) -> str:
|
||||||
"""SQLite file path for the LangGraph checkpointer."""
|
"""SQLite file path for the LangGraph checkpointer."""
|
||||||
return os.path.join(self.sqlite_dir, "checkpoints.db")
|
from deerflow.config.paths import resolve_path
|
||||||
|
|
||||||
|
resolved_dir = str(resolve_path(self.sqlite_dir))
|
||||||
|
return os.path.join(resolved_dir, "checkpoints.db")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app_sqlite_path(self) -> str:
|
def app_sqlite_path(self) -> str:
|
||||||
"""SQLite file path for application ORM data."""
|
"""SQLite file path for application ORM data."""
|
||||||
return os.path.join(self.sqlite_dir, "app.db")
|
from deerflow.config.paths import resolve_path
|
||||||
|
|
||||||
|
resolved_dir = str(resolve_path(self.sqlite_dir))
|
||||||
|
return os.path.join(resolved_dir, "app.db")
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def app_sqlalchemy_url(self) -> str:
|
def app_sqlalchemy_url(self) -> str:
|
||||||
|
|||||||
@@ -126,8 +126,13 @@ class RunRepository(RunStore):
|
|||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def list_pending(self, *, before=None):
|
async def list_pending(self, *, before=None):
|
||||||
now = before or datetime.now(UTC).isoformat()
|
if before is None:
|
||||||
stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= now).order_by(RunRow.created_at.asc())
|
before_dt = datetime.now(UTC)
|
||||||
|
elif isinstance(before, datetime):
|
||||||
|
before_dt = before
|
||||||
|
else:
|
||||||
|
before_dt = datetime.fromisoformat(before)
|
||||||
|
stmt = select(RunRow).where(RunRow.status == "pending", RunRow.created_at <= before_dt).order_by(RunRow.created_at.asc())
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
result = await session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
return [self._row_to_dict(r) for r in result.scalars()]
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|||||||
@@ -88,14 +88,22 @@ class ThreadMetaRepository(ThreadMetaStore):
|
|||||||
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
stmt = select(ThreadMetaRow).order_by(ThreadMetaRow.updated_at.desc())
|
||||||
if status:
|
if status:
|
||||||
stmt = stmt.where(ThreadMetaRow.status == status)
|
stmt = stmt.where(ThreadMetaRow.status == status)
|
||||||
stmt = stmt.limit(limit).offset(offset)
|
|
||||||
async with self._sf() as session:
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
rows = [self._row_to_dict(r) for r in result.scalars()]
|
|
||||||
|
|
||||||
if metadata:
|
if metadata:
|
||||||
|
# When metadata filter is active, fetch a larger window and filter
|
||||||
|
# in Python. TODO(Phase 2): use JSON DB operators (Postgres @>,
|
||||||
|
# SQLite json_extract) for server-side filtering.
|
||||||
|
stmt = stmt.limit(limit * 5 + offset)
|
||||||
|
async with self._sf() as session:
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
rows = [self._row_to_dict(r) for r in result.scalars()]
|
||||||
rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())]
|
rows = [r for r in rows if all(r.get("metadata", {}).get(k) == v for k, v in metadata.items())]
|
||||||
return rows
|
return rows[offset : offset + limit]
|
||||||
|
else:
|
||||||
|
stmt = stmt.limit(limit).offset(offset)
|
||||||
|
async with self._sf() as session:
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return [self._row_to_dict(r) for r in result.scalars()]
|
||||||
|
|
||||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||||
"""Update the display_name (title) for a thread."""
|
"""Update the display_name (title) for a thread."""
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -22,16 +23,27 @@ from deerflow.runtime.events.store.base import RunEventStore
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SAFE_ID_PATTERN = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||||
|
|
||||||
|
|
||||||
class JsonlRunEventStore(RunEventStore):
|
class JsonlRunEventStore(RunEventStore):
|
||||||
def __init__(self, base_dir: str | Path | None = None):
|
def __init__(self, base_dir: str | Path | None = None):
|
||||||
self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
|
self._base_dir = Path(base_dir) if base_dir else Path(".deer-flow")
|
||||||
self._seq_counters: dict[str, int] = {} # thread_id -> current max seq
|
self._seq_counters: dict[str, int] = {} # thread_id -> current max seq
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _validate_id(value: str, label: str) -> str:
|
||||||
|
"""Validate that an ID is safe for use in filesystem paths."""
|
||||||
|
if not value or not _SAFE_ID_PATTERN.match(value):
|
||||||
|
raise ValueError(f"Invalid {label}: must be alphanumeric/dash/underscore, got {value!r}")
|
||||||
|
return value
|
||||||
|
|
||||||
def _thread_dir(self, thread_id: str) -> Path:
|
def _thread_dir(self, thread_id: str) -> Path:
|
||||||
|
self._validate_id(thread_id, "thread_id")
|
||||||
return self._base_dir / "threads" / thread_id / "runs"
|
return self._base_dir / "threads" / thread_id / "runs"
|
||||||
|
|
||||||
def _run_file(self, thread_id: str, run_id: str) -> Path:
|
def _run_file(self, thread_id: str, run_id: str) -> Path:
|
||||||
|
self._validate_id(run_id, "run_id")
|
||||||
return self._thread_dir(thread_id) / f"{run_id}.jsonl"
|
return self._thread_dir(thread_id) / f"{run_id}.jsonl"
|
||||||
|
|
||||||
def _next_seq(self, thread_id: str) -> int:
|
def _next_seq(self, thread_id: str) -> int:
|
||||||
|
|||||||
Reference in New Issue
Block a user