mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 07:56:48 +00:00
feat(persistence): add user feedback + follow-up run association
Phase 2-C: feedback and follow-up tracking. - FeedbackRow ORM model (rating +1/-1, optional message_id, comment) - FeedbackRepository with CRUD, list_by_run/thread, aggregate stats - Feedback API endpoints: create, list, stats, delete - follow_up_to_run_id in RunCreateRequest (explicit or auto-detected from latest successful run on the thread) - Worker writes follow_up_to_run_id into human_message event metadata - Gateway deps: feedback_repo factory + getter - 17 new tests (14 FeedbackRepository + 3 follow-up association) - 109 total tests pass, zero regressions Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from deerflow.persistence.models.feedback import FeedbackRow
|
||||
from deerflow.persistence.models.run import RunRow
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
from deerflow.persistence.models.thread_meta import ThreadMetaRow
|
||||
|
||||
__all__ = ["RunEventRow", "RunRow", "ThreadMetaRow"]
|
||||
__all__ = ["FeedbackRow", "RunEventRow", "RunRow", "ThreadMetaRow"]
|
||||
|
||||
@@ -0,0 +1,30 @@
|
||||
"""ORM model for user feedback on runs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from deerflow.persistence.base import Base
|
||||
|
||||
|
||||
class FeedbackRow(Base):
|
||||
__tablename__ = "feedback"
|
||||
|
||||
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
run_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), nullable=False, index=True)
|
||||
owner_id: Mapped[str | None] = mapped_column(String(64), index=True)
|
||||
message_id: Mapped[str | None] = mapped_column(String(64))
|
||||
# message_id is an optional RunEventStore event identifier —
|
||||
# allows feedback to target a specific message or the entire run
|
||||
|
||||
rating: Mapped[int] = mapped_column(nullable=False)
|
||||
# +1 (thumbs-up) or -1 (thumbs-down)
|
||||
|
||||
comment: Mapped[str | None] = mapped_column(Text)
|
||||
# Optional text feedback from the user
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
||||
@@ -1,4 +1,5 @@
|
||||
from deerflow.persistence.repositories.feedback_repo import FeedbackRepository
|
||||
from deerflow.persistence.repositories.run_repo import RunRepository
|
||||
from deerflow.persistence.repositories.thread_meta_repo import ThreadMetaRepository
|
||||
|
||||
__all__ = ["RunRepository", "ThreadMetaRepository"]
|
||||
__all__ = ["FeedbackRepository", "RunRepository", "ThreadMetaRepository"]
|
||||
|
||||
@@ -0,0 +1,97 @@
|
||||
"""SQLAlchemy-backed feedback storage.
|
||||
|
||||
Each method acquires its own short-lived session.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.feedback import FeedbackRow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FeedbackRepository:
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: FeedbackRow) -> dict:
|
||||
d = row.to_dict()
|
||||
val = d.get("created_at")
|
||||
if isinstance(val, datetime):
|
||||
d["created_at"] = val.isoformat()
|
||||
return d
|
||||
|
||||
async def create(
|
||||
self,
|
||||
*,
|
||||
run_id: str,
|
||||
thread_id: str,
|
||||
rating: int,
|
||||
owner_id: str | None = None,
|
||||
message_id: str | None = None,
|
||||
comment: str | None = None,
|
||||
) -> dict:
|
||||
"""Create a feedback record. rating must be +1 or -1."""
|
||||
if rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {rating}")
|
||||
row = FeedbackRow(
|
||||
feedback_id=str(uuid.uuid4()),
|
||||
run_id=run_id,
|
||||
thread_id=thread_id,
|
||||
owner_id=owner_id,
|
||||
message_id=message_id,
|
||||
rating=rating,
|
||||
comment=comment,
|
||||
created_at=datetime.now(UTC),
|
||||
)
|
||||
async with self._sf() as session:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def get(self, feedback_id: str) -> dict | None:
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
return self._row_to_dict(row) if row else None
|
||||
|
||||
async def list_by_run(self, thread_id: str, run_id: str, *, limit: int = 100) -> list[dict]:
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id, FeedbackRow.run_id == run_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def list_by_thread(self, thread_id: str, *, limit: int = 100) -> list[dict]:
|
||||
stmt = select(FeedbackRow).where(FeedbackRow.thread_id == thread_id).order_by(FeedbackRow.created_at.asc()).limit(limit)
|
||||
async with self._sf() as session:
|
||||
result = await session.execute(stmt)
|
||||
return [self._row_to_dict(r) for r in result.scalars()]
|
||||
|
||||
async def delete(self, feedback_id: str) -> bool:
|
||||
async with self._sf() as session:
|
||||
row = await session.get(FeedbackRow, feedback_id)
|
||||
if row is None:
|
||||
return False
|
||||
await session.delete(row)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
async def aggregate_by_run(self, thread_id: str, run_id: str) -> dict:
|
||||
"""Aggregate feedback stats for a run."""
|
||||
items = await self.list_by_run(thread_id, run_id, limit=10000)
|
||||
positive = sum(1 for i in items if i["rating"] == 1)
|
||||
negative = sum(1 for i in items if i["rating"] == -1)
|
||||
return {
|
||||
"run_id": run_id,
|
||||
"total": len(items),
|
||||
"positive": positive,
|
||||
"negative": negative,
|
||||
}
|
||||
@@ -47,6 +47,7 @@ async def run_agent(
|
||||
interrupt_after: list[str] | Literal["*"] | None = None,
|
||||
event_store: Any | None = None,
|
||||
run_events_config: Any | None = None,
|
||||
follow_up_to_run_id: str | None = None,
|
||||
) -> None:
|
||||
"""Execute an agent in the background, publishing events to *bridge*."""
|
||||
|
||||
@@ -69,12 +70,16 @@ async def run_agent(
|
||||
# Write human_message event
|
||||
user_input = _extract_user_input(graph_input)
|
||||
if user_input:
|
||||
msg_metadata = {}
|
||||
if follow_up_to_run_id:
|
||||
msg_metadata["follow_up_to_run_id"] = follow_up_to_run_id
|
||||
await event_store.put(
|
||||
thread_id=thread_id,
|
||||
run_id=run_id,
|
||||
event_type="human_message",
|
||||
category="message",
|
||||
content=user_input,
|
||||
metadata=msg_metadata or None,
|
||||
)
|
||||
journal.set_first_human_message(user_input)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user