mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 08:55:59 +00:00
feat(feedback): add delete_by_run() and list_by_thread_grouped()
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -162,6 +162,44 @@ class FeedbackRepository:
|
|||||||
await session.refresh(row)
|
await session.refresh(row)
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
|
async def delete_by_run(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
thread_id: str,
|
||||||
|
run_id: str,
|
||||||
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
|
) -> bool:
|
||||||
|
"""Delete the current user's feedback for a run. Returns True if a record was deleted."""
|
||||||
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.delete_by_run")
|
||||||
|
async with self._sf() as session:
|
||||||
|
stmt = select(FeedbackRow).where(
|
||||||
|
FeedbackRow.thread_id == thread_id,
|
||||||
|
FeedbackRow.run_id == run_id,
|
||||||
|
FeedbackRow.user_id == resolved_user_id,
|
||||||
|
)
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
return False
|
||||||
|
await session.delete(row)
|
||||||
|
await session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
async def list_by_thread_grouped(
|
||||||
|
self,
|
||||||
|
thread_id: str,
|
||||||
|
*,
|
||||||
|
user_id: str | None | _AutoSentinel = AUTO,
|
||||||
|
) -> dict[str, dict]:
|
||||||
|
"""Return feedback grouped by run_id for a thread: {run_id: feedback_dict}."""
|
||||||
|
resolved_user_id = resolve_user_id(user_id, method_name="FeedbackRepository.list_by_thread_grouped")
|
||||||
|
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id)
|
||||||
|
if resolved_user_id is not None:
|
||||||
|
stmt = stmt.where(FeedbackRow.user_id == resolved_user_id)
|
||||||
|
async with self._sf() as session:
|
||||||
|
result = await session.execute(stmt)
|
||||||
|
return {row.run_id: self._row_to_dict(row) for row in result.scalars()}
|
||||||
|
|
||||||
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
|
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
|
||||||
"""Aggregate feedback stats for a run using database-side counting."""
|
"""Aggregate feedback stats for a run using database-side counting."""
|
||||||
stmt = select(
|
stmt = select(
|
||||||
|
|||||||
@@ -190,6 +190,44 @@ class TestFeedbackRepository:
|
|||||||
await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
|
await repo.upsert(run_id="r1", thread_id="t1", rating=0, user_id="u1")
|
||||||
await _cleanup()
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_by_run(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||||
|
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
||||||
|
assert deleted is True
|
||||||
|
results = await repo.list_by_run("t1", "r1", user_id="u1")
|
||||||
|
assert len(results) == 0
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_delete_by_run_nonexistent(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
deleted = await repo.delete_by_run(thread_id="t1", run_id="r1", user_id="u1")
|
||||||
|
assert deleted is False
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_by_thread_grouped(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
await repo.upsert(run_id="r1", thread_id="t1", rating=1, user_id="u1")
|
||||||
|
await repo.upsert(run_id="r2", thread_id="t1", rating=-1, user_id="u1")
|
||||||
|
await repo.upsert(run_id="r3", thread_id="t2", rating=1, user_id="u1")
|
||||||
|
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||||
|
assert "r1" in grouped
|
||||||
|
assert "r2" in grouped
|
||||||
|
assert "r3" not in grouped
|
||||||
|
assert grouped["r1"]["rating"] == 1
|
||||||
|
assert grouped["r2"]["rating"] == -1
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_list_by_thread_grouped_empty(self, tmp_path):
|
||||||
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
|
grouped = await repo.list_by_thread_grouped("t1", user_id="u1")
|
||||||
|
assert grouped == {}
|
||||||
|
await _cleanup()
|
||||||
|
|
||||||
|
|
||||||
# -- Follow-up association --
|
# -- Follow-up association --
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user