fix(persistence): address 22 review comments from CodeQL, Copilot, and Code Quality

Bug fixes:
- Sanitize log params to prevent log injection (CodeQL)
- Reset threads_meta.status to idle/error when run completes
- Attach messages only to latest checkpoint in /history response
- Write threads_meta on POST /threads so new threads appear in search

Lint fixes:
- Remove unused imports (journal.py, migrations/env.py, test_converters.py)
- Convert lambda to named function (engine.py, Ruff E731)
- Remove unused logger definitions in repos (Ruff F841)
- Add logging to JSONL decode errors and empty except blocks
- Separate assert side-effects in tests (CodeQL)
- Remove unused local variables in tests (Ruff F841)
- Fix max_trace_content truncation to use byte length, not char length

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
rayhpeng
2026-04-05 22:49:26 +08:00
parent 32f69674a5
commit b94383c93a
15 changed files with 94 additions and 55 deletions
+39 -23
View File
@@ -35,6 +35,11 @@ logger = logging.getLogger(__name__)
router = APIRouter(prefix="/api/threads", tags=["threads"]) router = APIRouter(prefix="/api/threads", tags=["threads"])
def _sanitize_log_param(value: str) -> str:
"""Strip control characters to prevent log injection."""
return value.replace("\n", "").replace("\r", "").replace("\x00", "")
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Response / request models # Response / request models
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -136,13 +141,13 @@ def _delete_thread_data(thread_id: str, paths: Paths | None = None) -> ThreadDel
raise HTTPException(status_code=422, detail=str(exc)) from exc raise HTTPException(status_code=422, detail=str(exc)) from exc
except FileNotFoundError: except FileNotFoundError:
# Not critical — thread data may not exist on disk # Not critical — thread data may not exist on disk
logger.debug("No local thread data to delete for %s", thread_id) logger.debug("No local thread data to delete for %s", _sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}") return ThreadDeleteResponse(success=True, message=f"No local data for {thread_id}")
except Exception as exc: except Exception as exc:
logger.exception("Failed to delete thread data for %s", thread_id) logger.exception("Failed to delete thread data for %s", _sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc raise HTTPException(status_code=500, detail="Failed to delete local thread data.") from exc
logger.info("Deleted local thread data for %s", thread_id) logger.info("Deleted local thread data for %s", _sanitize_log_param(thread_id))
return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}") return ThreadDeleteResponse(success=True, message=f"Deleted local thread data for {thread_id}")
@@ -231,7 +236,7 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
try: try:
await store.adelete(THREADS_NS, thread_id) await store.adelete(THREADS_NS, thread_id)
except Exception: except Exception:
logger.debug("Could not delete store record for thread %s (not critical)", thread_id) logger.debug("Could not delete store record for thread %s (not critical)", _sanitize_log_param(thread_id))
# Remove checkpoints (best-effort) # Remove checkpoints (best-effort)
checkpointer = getattr(request.app.state, "checkpointer", None) checkpointer = getattr(request.app.state, "checkpointer", None)
@@ -240,7 +245,7 @@ async def delete_thread_data(thread_id: str, request: Request) -> ThreadDeleteRe
if hasattr(checkpointer, "adelete_thread"): if hasattr(checkpointer, "adelete_thread"):
await checkpointer.adelete_thread(thread_id) await checkpointer.adelete_thread(thread_id)
except Exception: except Exception:
logger.debug("Could not delete checkpoints for thread %s (not critical)", thread_id) logger.debug("Could not delete checkpoints for thread %s (not critical)", _sanitize_log_param(thread_id))
return response return response
@@ -284,7 +289,7 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
}, },
) )
except Exception: except Exception:
logger.exception("Failed to write thread %s to store", thread_id) logger.exception("Failed to write thread %s to store", _sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread") raise HTTPException(status_code=500, detail="Failed to create thread")
# Write an empty checkpoint so state endpoints work immediately # Write an empty checkpoint so state endpoints work immediately
@@ -302,10 +307,24 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
} }
await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {}) await checkpointer.aput(config, empty_checkpoint(), ckpt_metadata, {})
except Exception: except Exception:
logger.exception("Failed to create checkpoint for thread %s", thread_id) logger.exception("Failed to create checkpoint for thread %s", _sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to create thread") raise HTTPException(status_code=500, detail="Failed to create thread")
logger.info("Thread created: %s", thread_id) # Write thread_meta so the thread appears in /threads/search immediately
from app.gateway.deps import get_thread_meta_repo
thread_meta_repo = get_thread_meta_repo(request)
if thread_meta_repo is not None:
try:
await thread_meta_repo.create(
thread_id,
assistant_id=getattr(body, "assistant_id", None),
metadata=body.metadata,
)
except Exception:
logger.debug("Failed to upsert thread_meta on create for %s (non-fatal)", _sanitize_log_param(thread_id))
logger.info("Thread created: %s", _sanitize_log_param(thread_id))
return ThreadResponse( return ThreadResponse(
thread_id=thread_id, thread_id=thread_id,
status="idle", status="idle",
@@ -372,7 +391,7 @@ async def patch_thread(thread_id: str, body: ThreadPatchRequest, request: Reques
try: try:
await _store_put(store, updated) await _store_put(store, updated)
except Exception: except Exception:
logger.exception("Failed to patch thread %s", thread_id) logger.exception("Failed to patch thread %s", _sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread") raise HTTPException(status_code=500, detail="Failed to update thread")
return ThreadResponse( return ThreadResponse(
@@ -404,7 +423,7 @@ async def get_thread(thread_id: str, request: Request) -> ThreadResponse:
try: try:
checkpoint_tuple = await checkpointer.aget_tuple(config) checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception: except Exception:
logger.exception("Failed to get checkpoint for thread %s", thread_id) logger.exception("Failed to get checkpoint for thread %s", _sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread") raise HTTPException(status_code=500, detail="Failed to get thread")
if record is None and checkpoint_tuple is None: if record is None and checkpoint_tuple is None:
@@ -452,7 +471,7 @@ async def get_thread_state(thread_id: str, request: Request) -> ThreadStateRespo
try: try:
checkpoint_tuple = await checkpointer.aget_tuple(config) checkpoint_tuple = await checkpointer.aget_tuple(config)
except Exception: except Exception:
logger.exception("Failed to get state for thread %s", thread_id) logger.exception("Failed to get state for thread %s", _sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state") raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None: if checkpoint_tuple is None:
@@ -514,7 +533,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
try: try:
checkpoint_tuple = await checkpointer.aget_tuple(read_config) checkpoint_tuple = await checkpointer.aget_tuple(read_config)
except Exception: except Exception:
logger.exception("Failed to get state for thread %s", thread_id) logger.exception("Failed to get state for thread %s", _sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread state") raise HTTPException(status_code=500, detail="Failed to get thread state")
if checkpoint_tuple is None: if checkpoint_tuple is None:
@@ -548,7 +567,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
try: try:
new_config = await checkpointer.aput(write_config, checkpoint, metadata, {}) new_config = await checkpointer.aput(write_config, checkpoint, metadata, {})
except Exception: except Exception:
logger.exception("Failed to update state for thread %s", thread_id) logger.exception("Failed to update state for thread %s", _sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to update thread state") raise HTTPException(status_code=500, detail="Failed to update thread state")
new_checkpoint_id: str | None = None new_checkpoint_id: str | None = None
@@ -560,7 +579,7 @@ async def update_thread_state(thread_id: str, body: ThreadStateUpdateRequest, re
try: try:
await _store_upsert(store, thread_id, values={"title": body.values["title"]}) await _store_upsert(store, thread_id, values={"title": body.values["title"]})
except Exception: except Exception:
logger.debug("Failed to sync title to store for thread %s (non-fatal)", thread_id) logger.debug("Failed to sync title to store for thread %s (non-fatal)", _sanitize_log_param(thread_id))
return ThreadStateResponse( return ThreadStateResponse(
values=serialize_channel_values(channel_values), values=serialize_channel_values(channel_values),
@@ -594,16 +613,12 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
try: try:
all_messages = await event_store.list_messages(thread_id, limit=10_000) all_messages = await event_store.list_messages(thread_id, limit=10_000)
except Exception: except Exception:
logger.warning("Failed to load messages from event store for thread %s", thread_id, exc_info=True) logger.warning("Failed to load messages from event store for thread %s", _sanitize_log_param(thread_id), exc_info=True)
all_messages = [] all_messages = []
# Group messages by run_id for per-checkpoint assembly
messages_by_run: dict[str, list[dict]] = {}
for msg in all_messages:
run_id = msg.get("run_id", "")
messages_by_run.setdefault(run_id, []).append(msg.get("content", {}))
entries: list[HistoryEntry] = [] entries: list[HistoryEntry] = []
is_latest_checkpoint = True
try: try:
async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit): async for checkpoint_tuple in checkpointer.alist(config, limit=body.limit):
ckpt_config = getattr(checkpoint_tuple, "config", {}) ckpt_config = getattr(checkpoint_tuple, "config", {})
@@ -625,9 +640,10 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
if thread_data := channel_values.get("thread_data"): if thread_data := channel_values.get("thread_data"):
values["thread_data"] = thread_data values["thread_data"] = thread_data
# Attach all messages from event store (not just this checkpoint's run) # Attach all messages only to the latest (first) checkpoint entry
if all_messages: if is_latest_checkpoint and all_messages:
values["messages"] = [m.get("content", {}) for m in all_messages] values["messages"] = [m.get("content", {}) for m in all_messages]
is_latest_checkpoint = False
# Derive next tasks # Derive next tasks
tasks_raw = getattr(checkpoint_tuple, "tasks", []) or [] tasks_raw = getattr(checkpoint_tuple, "tasks", []) or []
@@ -650,7 +666,7 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
) )
) )
except Exception: except Exception:
logger.exception("Failed to get history for thread %s", thread_id) logger.exception("Failed to get history for thread %s", _sanitize_log_param(thread_id))
raise HTTPException(status_code=500, detail="Failed to get thread history") raise HTTPException(status_code=500, detail="Failed to get thread history")
return entries return entries
+3 -2
View File
@@ -18,6 +18,7 @@ from fastapi import HTTPException, Request
from langchain_core.messages import HumanMessage from langchain_core.messages import HumanMessage
from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo from app.gateway.deps import get_checkpointer, get_run_event_store, get_run_manager, get_run_store, get_store, get_stream_bridge, get_thread_meta_repo
from app.gateway.routers.threads import _sanitize_log_param
from deerflow.runtime import ( from deerflow.runtime import (
END_SENTINEL, END_SENTINEL,
HEARTBEAT_SENTINEL, HEARTBEAT_SENTINEL,
@@ -184,7 +185,7 @@ async def _upsert_thread_in_store(store, thread_id: str, metadata: dict | None)
try: try:
await _store_upsert(store, thread_id, metadata=metadata) await _store_upsert(store, thread_id, metadata=metadata)
except Exception: except Exception:
logger.warning("Failed to upsert thread %s in store (non-fatal)", thread_id) logger.warning("Failed to upsert thread %s in store (non-fatal)", _sanitize_log_param(thread_id))
async def _sync_thread_title_after_run( async def _sync_thread_title_after_run(
@@ -312,7 +313,7 @@ async def start_run(
else: else:
await thread_meta_repo.update_status(thread_id, "running") await thread_meta_repo.update_status(thread_id, "running")
except Exception: except Exception:
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", thread_id) logger.warning("Failed to upsert thread_meta for %s (non-fatal)", _sanitize_log_param(thread_id))
agent_factory = resolve_agent_factory(body.assistant_id) agent_factory = resolve_agent_factory(body.assistant_id)
graph_input = normalize_input(body.input) graph_input = normalize_input(body.input)
@@ -10,13 +10,15 @@ None and fall back to in-memory implementations.
from __future__ import annotations from __future__ import annotations
import logging
import json import json
import logging
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
_json_serializer = lambda obj: json.dumps(obj, ensure_ascii=False)
def _json_serializer(obj: object) -> str:
"""JSON serializer with ensure_ascii=False for Chinese character support."""
return json.dumps(obj, ensure_ascii=False)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -106,7 +108,9 @@ async def init_engine(
try: try:
import deerflow.persistence.models # noqa: F401 import deerflow.persistence.models # noqa: F401
except ImportError: except ImportError:
pass # Models package not yet available — tables won't be auto-created.
# This is expected during initial scaffolding or minimal installs.
logger.debug("deerflow.persistence.models not found; skipping auto-create tables")
try: try:
async with _engine.begin() as conn: async with _engine.begin() as conn:
@@ -8,6 +8,7 @@ have their own schema lifecycle and must not be touched by Alembic.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import logging
from logging.config import fileConfig from logging.config import fileConfig
from alembic import context from alembic import context
@@ -17,9 +18,13 @@ from deerflow.persistence.base import Base
# Import all models so metadata is populated. # Import all models so metadata is populated.
try: try:
import deerflow.persistence.models # noqa: F401 import deerflow.persistence.models # noqa: F401 — register ORM models with Base.metadata
except ImportError: except ImportError:
pass # Models not available — migration will work with existing metadata only.
logging.getLogger(__name__).warning(
"Could not import deerflow.persistence.models; "
"Alembic may not detect all tables"
)
config = context.config config = context.config
if config.config_file_name is not None: if config.config_file_name is not None:
@@ -5,7 +5,6 @@ Each method acquires its own short-lived session.
from __future__ import annotations from __future__ import annotations
import logging
import uuid import uuid
from datetime import UTC, datetime from datetime import UTC, datetime
@@ -14,8 +13,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.feedback import FeedbackRow from deerflow.persistence.models.feedback import FeedbackRow
logger = logging.getLogger(__name__)
class FeedbackRepository: class FeedbackRepository:
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
@@ -8,7 +8,6 @@ minutes -- we don't hold connections across long execution.
from __future__ import annotations from __future__ import annotations
import json import json
import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
@@ -18,8 +17,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run import RunRow from deerflow.persistence.models.run import RunRow
from deerflow.runtime.runs.store.base import RunStore from deerflow.runtime.runs.store.base import RunStore
logger = logging.getLogger(__name__)
class RunRepository(RunStore): class RunRepository(RunStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
@@ -2,7 +2,6 @@
from __future__ import annotations from __future__ import annotations
import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from typing import Any from typing import Any
@@ -11,8 +10,6 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.thread_meta import ThreadMetaRow from deerflow.persistence.models.thread_meta import ThreadMetaRow
logger = logging.getLogger(__name__)
class ThreadMetaRepository: class ThreadMetaRepository:
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None: def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
@@ -7,6 +7,7 @@ at ``max_trace_content`` bytes to avoid bloating the database.
from __future__ import annotations from __future__ import annotations
import json import json
import logging
from datetime import UTC, datetime from datetime import UTC, datetime
from sqlalchemy import delete, func, select from sqlalchemy import delete, func, select
@@ -15,6 +16,8 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow from deerflow.persistence.models.run_event import RunEventRow
from deerflow.runtime.events.store.base import RunEventStore from deerflow.runtime.events.store.base import RunEventStore
logger = logging.getLogger(__name__)
class DbRunEventStore(RunEventStore): class DbRunEventStore(RunEventStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240): def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
@@ -35,15 +38,19 @@ class DbRunEventStore(RunEventStore):
try: try:
d["content"] = json.loads(raw) d["content"] = json.loads(raw)
except (json.JSONDecodeError, ValueError): except (json.JSONDecodeError, ValueError):
pass # Content looked like JSON (content_is_dict flag) but failed to parse;
# keep the raw string as-is.
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
return d return d
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]: def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
if category == "trace": if category == "trace":
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
if len(text) > self._max_trace_content: encoded = text.encode("utf-8")
content = text[: self._max_trace_content] if len(encoded) > self._max_trace_content:
metadata = {**(metadata or {}), "content_truncated": True} # Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
content = encoded[: self._max_trace_content].decode("utf-8", errors="ignore")
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
return content, metadata or {} return content, metadata or {}
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):
@@ -51,6 +51,7 @@ class JsonlRunEventStore(RunEventStore):
record = json.loads(line) record = json.loads(line)
max_seq = max(max_seq, record.get("seq", 0)) max_seq = max(max_seq, record.get("seq", 0))
except json.JSONDecodeError: except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", f)
continue continue
self._seq_counters[thread_id] = max_seq self._seq_counters[thread_id] = max_seq
@@ -73,6 +74,7 @@ class JsonlRunEventStore(RunEventStore):
try: try:
events.append(json.loads(line)) events.append(json.loads(line))
except json.JSONDecodeError: except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", f)
continue continue
events.sort(key=lambda e: e.get("seq", 0)) events.sort(key=lambda e: e.get("seq", 0))
return events return events
@@ -89,6 +91,7 @@ class JsonlRunEventStore(RunEventStore):
try: try:
events.append(json.loads(line)) events.append(json.loads(line))
except json.JSONDecodeError: except json.JSONDecodeError:
logger.debug("Skipping malformed JSONL line in %s", path)
continue continue
events.sort(key=lambda e: e.get("seq", 0)) events.sort(key=lambda e: e.get("seq", 0))
return events return events
@@ -135,7 +135,7 @@ class RunJournal(BaseCallbackHandler):
self._llm_start_times[str(run_id)] = time.monotonic() self._llm_start_times[str(run_id)] = time.monotonic()
def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None: def on_llm_end(self, response: Any, *, run_id: UUID, **kwargs: Any) -> None:
from deerflow.runtime.converters import langchain_to_openai_completion, langchain_to_openai_message from deerflow.runtime.converters import langchain_to_openai_completion
try: try:
message = response.generations[0][0].message message = response.generations[0][0].message
@@ -17,7 +17,10 @@ from __future__ import annotations
import asyncio import asyncio
import logging import logging
from typing import Any, Literal from typing import TYPE_CHECKING, Any, Literal
if TYPE_CHECKING:
from langchain_core.messages import HumanMessage
from deerflow.runtime.serialization import serialize from deerflow.runtime.serialization import serialize
from deerflow.runtime.stream_bridge import StreamBridge from deerflow.runtime.stream_bridge import StreamBridge
@@ -273,6 +276,14 @@ async def run_agent(
except Exception: except Exception:
logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id) logger.debug("Failed to sync title for thread %s (non-fatal)", thread_id)
# Update threads_meta status based on run outcome
if thread_meta_repo is not None:
try:
final_status = "idle" if record.status == RunStatus.success else record.status.value
await thread_meta_repo.update_status(thread_id, final_status)
except Exception:
logger.debug("Failed to update thread_meta status for %s (non-fatal)", thread_id)
await bridge.publish_end(run_id) await bridge.publish_end(run_id)
asyncio.create_task(bridge.cleanup(run_id, delay=60)) asyncio.create_task(bridge.cleanup(run_id, delay=60))
@@ -294,7 +305,7 @@ def _lg_mode_to_sse_event(mode: str) -> str:
return mode return mode
def _extract_human_message(graph_input: dict) -> "HumanMessage | None": def _extract_human_message(graph_input: dict) -> HumanMessage | None:
"""Extract or construct a HumanMessage from graph_input for event recording. """Extract or construct a HumanMessage from graph_input for event recording.
Returns a LangChain HumanMessage so callers can use .model_dump() to get Returns a LangChain HumanMessage so callers can use .model_dump() to get
-3
View File
@@ -5,10 +5,7 @@ from __future__ import annotations
import json import json
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest
from deerflow.runtime.converters import ( from deerflow.runtime.converters import (
_infer_finish_reason,
langchain_messages_to_openai, langchain_messages_to_openai,
langchain_to_openai_completion, langchain_to_openai_completion,
langchain_to_openai_message, langchain_to_openai_message,
+4 -2
View File
@@ -117,14 +117,16 @@ class TestFeedbackRepository:
async def test_delete(self, tmp_path): async def test_delete(self, tmp_path):
repo = await _make_feedback_repo(tmp_path) repo = await _make_feedback_repo(tmp_path)
created = await repo.create(run_id="r1", thread_id="t1", rating=1) created = await repo.create(run_id="r1", thread_id="t1", rating=1)
assert await repo.delete(created["feedback_id"]) is True deleted = await repo.delete(created["feedback_id"])
assert deleted is True
assert await repo.get(created["feedback_id"]) is None assert await repo.get(created["feedback_id"]) is None
await _cleanup() await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
async def test_delete_nonexistent(self, tmp_path): async def test_delete_nonexistent(self, tmp_path):
repo = await _make_feedback_repo(tmp_path) repo = await _make_feedback_repo(tmp_path)
assert await repo.delete("nonexistent") is False deleted = await repo.delete("nonexistent")
assert deleted is False
await _cleanup() await _cleanup()
@pytest.mark.anyio @pytest.mark.anyio
+3 -1
View File
@@ -225,6 +225,8 @@ class TestEngineLifecycle:
pytest.skip("asyncpg is installed -- cannot test missing-dep path") pytest.skip("asyncpg is installed -- cannot test missing-dep path")
except ImportError: except ImportError:
pass # asyncpg is not installed — this is the expected state for this test.
# We proceed to verify that init_engine raises an actionable ImportError.
pass # noqa: S110 — intentionally ignored
with pytest.raises(ImportError, match="uv sync --extra postgres"): with pytest.raises(ImportError, match="uv sync --extra postgres"):
await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x") await init_engine("postgres", url="postgresql+asyncpg://x:x@localhost/x")
+2 -2
View File
@@ -456,7 +456,7 @@ class TestDictContentFlag:
sf = get_session_factory() sf = get_session_factory()
store = DbRunEventStore(sf) store = DbRunEventStore(sf)
record = await store.put( await store.put(
thread_id="t1", thread_id="t1",
run_id="r1", run_id="r1",
event_type="tool_end", event_type="tool_end",
@@ -480,7 +480,7 @@ class TestDictContentFlag:
sf = get_session_factory() sf = get_session_factory()
store = DbRunEventStore(sf) store = DbRunEventStore(sf)
record = await store.put( await store.put(
thread_id="t1", thread_id="t1",
run_id="r1", run_id="r1",
event_type="tool_end", event_type="tool_end",