mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-25 01:15:58 +00:00
feat(storage): implement unified persistence layer with database models and repositories
This commit is contained in:
@@ -0,0 +1,11 @@
|
||||
from store.repositories.db.feedback import DbFeedbackRepository
|
||||
from store.repositories.db.run import DbRunRepository
|
||||
from store.repositories.db.run_event import DbRunEventRepository
|
||||
from store.repositories.db.thread_meta import DbThreadMetaRepository
|
||||
|
||||
__all__ = [
|
||||
"DbFeedbackRepository",
|
||||
"DbRunRepository",
|
||||
"DbRunEventRepository",
|
||||
"DbThreadMetaRepository",
|
||||
]
|
||||
@@ -0,0 +1,74 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.feedback import Feedback, FeedbackCreate, FeedbackRepositoryProtocol
|
||||
from store.repositories.models.feedback import Feedback as FeedbackModel
|
||||
|
||||
|
||||
def _to_feedback(m: FeedbackModel) -> Feedback:
|
||||
return Feedback(
|
||||
feedback_id=m.feedback_id,
|
||||
run_id=m.run_id,
|
||||
thread_id=m.thread_id,
|
||||
rating=m.rating,
|
||||
user_id=m.user_id,
|
||||
message_id=m.message_id,
|
||||
comment=m.comment,
|
||||
created_time=m.created_time,
|
||||
)
|
||||
|
||||
|
||||
class DbFeedbackRepository(FeedbackRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def create_feedback(self, data: FeedbackCreate) -> Feedback:
|
||||
if data.rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {data.rating}")
|
||||
model = FeedbackModel(
|
||||
feedback_id=data.feedback_id,
|
||||
run_id=data.run_id,
|
||||
thread_id=data.thread_id,
|
||||
rating=data.rating,
|
||||
user_id=data.user_id,
|
||||
message_id=data.message_id,
|
||||
comment=data.comment,
|
||||
)
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
return _to_feedback(model)
|
||||
|
||||
async def get_feedback(self, feedback_id: str) -> Feedback | None:
|
||||
result = await self._session.execute(
|
||||
select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_feedback(model) if model else None
|
||||
|
||||
async def list_feedback_by_run(self, run_id: str) -> list[Feedback]:
|
||||
result = await self._session.execute(
|
||||
select(FeedbackModel)
|
||||
.where(FeedbackModel.run_id == run_id)
|
||||
.order_by(FeedbackModel.created_time.desc())
|
||||
)
|
||||
return [_to_feedback(m) for m in result.scalars().all()]
|
||||
|
||||
async def list_feedback_by_thread(self, thread_id: str) -> list[Feedback]:
|
||||
result = await self._session.execute(
|
||||
select(FeedbackModel)
|
||||
.where(FeedbackModel.thread_id == thread_id)
|
||||
.order_by(FeedbackModel.created_time.desc())
|
||||
)
|
||||
return [_to_feedback(m) for m in result.scalars().all()]
|
||||
|
||||
async def delete_feedback(self, feedback_id: str) -> bool:
|
||||
existing = await self.get_feedback(feedback_id)
|
||||
if existing is None:
|
||||
return False
|
||||
await self._session.execute(
|
||||
delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)
|
||||
)
|
||||
return True
|
||||
@@ -0,0 +1,125 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.run import Run, RunCreate, RunRepositoryProtocol
|
||||
from store.repositories.models.run import Run as RunModel
|
||||
|
||||
|
||||
def _to_run(m: RunModel) -> Run:
|
||||
return Run(
|
||||
run_id=m.run_id,
|
||||
thread_id=m.thread_id,
|
||||
assistant_id=m.assistant_id,
|
||||
user_id=m.user_id,
|
||||
status=m.status,
|
||||
model_name=m.model_name,
|
||||
multitask_strategy=m.multitask_strategy,
|
||||
error=m.error,
|
||||
follow_up_to_run_id=m.follow_up_to_run_id,
|
||||
metadata=dict(m.meta or {}),
|
||||
kwargs=dict(m.kwargs or {}),
|
||||
total_input_tokens=m.total_input_tokens,
|
||||
total_output_tokens=m.total_output_tokens,
|
||||
total_tokens=m.total_tokens,
|
||||
llm_call_count=m.llm_call_count,
|
||||
lead_agent_tokens=m.lead_agent_tokens,
|
||||
subagent_tokens=m.subagent_tokens,
|
||||
middleware_tokens=m.middleware_tokens,
|
||||
message_count=m.message_count,
|
||||
first_human_message=m.first_human_message,
|
||||
last_ai_message=m.last_ai_message,
|
||||
created_time=m.created_time,
|
||||
updated_time=m.updated_time,
|
||||
)
|
||||
|
||||
|
||||
class DbRunRepository(RunRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def create_run(self, data: RunCreate) -> Run:
|
||||
model = RunModel(
|
||||
run_id=data.run_id,
|
||||
thread_id=data.thread_id,
|
||||
assistant_id=data.assistant_id,
|
||||
user_id=data.user_id,
|
||||
status=data.status,
|
||||
model_name=data.model_name,
|
||||
multitask_strategy=data.multitask_strategy,
|
||||
follow_up_to_run_id=data.follow_up_to_run_id,
|
||||
meta=dict(data.metadata),
|
||||
kwargs=dict(data.kwargs),
|
||||
)
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
return _to_run(model)
|
||||
|
||||
async def get_run(self, run_id: str) -> Run | None:
|
||||
result = await self._session.execute(
|
||||
select(RunModel).where(RunModel.run_id == run_id)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_run(model) if model else None
|
||||
|
||||
async def list_runs_by_thread(
|
||||
self, thread_id: str, *, limit: int = 50, offset: int = 0
|
||||
) -> list[Run]:
|
||||
result = await self._session.execute(
|
||||
select(RunModel)
|
||||
.where(RunModel.thread_id == thread_id)
|
||||
.order_by(RunModel.created_time.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
return [_to_run(m) for m in result.scalars().all()]
|
||||
|
||||
async def update_run_status(
|
||||
self, run_id: str, status: str, *, error: str | None = None
|
||||
) -> None:
|
||||
values: dict = {"status": status}
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
await self._session.execute(
|
||||
update(RunModel).where(RunModel.run_id == run_id).values(**values)
|
||||
)
|
||||
|
||||
async def delete_run(self, run_id: str) -> None:
|
||||
await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))
|
||||
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
status: str,
|
||||
total_input_tokens: int = 0,
|
||||
total_output_tokens: int = 0,
|
||||
total_tokens: int = 0,
|
||||
llm_call_count: int = 0,
|
||||
lead_agent_tokens: int = 0,
|
||||
subagent_tokens: int = 0,
|
||||
middleware_tokens: int = 0,
|
||||
message_count: int = 0,
|
||||
first_human_message: str | None = None,
|
||||
last_ai_message: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
values = {
|
||||
"status": status,
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"llm_call_count": llm_call_count,
|
||||
"lead_agent_tokens": lead_agent_tokens,
|
||||
"subagent_tokens": subagent_tokens,
|
||||
"middleware_tokens": middleware_tokens,
|
||||
"message_count": message_count,
|
||||
"first_human_message": first_human_message,
|
||||
"last_ai_message": last_ai_message,
|
||||
"error": error,
|
||||
}
|
||||
await self._session.execute(
|
||||
update(RunModel).where(RunModel.run_id == run_id).values(**values)
|
||||
)
|
||||
@@ -0,0 +1,180 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
|
||||
from store.repositories.models.run_event import RunEvent as RunEventModel
|
||||
|
||||
|
||||
def _serialize_content(content: str | dict[str, Any], metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||
if isinstance(content, dict):
|
||||
return json.dumps(content, default=str, ensure_ascii=False), {**metadata, "content_is_dict": True}
|
||||
return content, metadata
|
||||
|
||||
|
||||
def _deserialize_content(content: str, metadata: dict[str, Any]) -> str | dict[str, Any]:
|
||||
if not metadata.get("content_is_dict"):
|
||||
return content
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
|
||||
def _to_run_event(model: RunEventModel) -> RunEvent:
|
||||
raw_metadata = dict(model.meta or {})
|
||||
metadata = {key: value for key, value in raw_metadata.items() if key != "content_is_dict"}
|
||||
return RunEvent(
|
||||
thread_id=model.thread_id,
|
||||
run_id=model.run_id,
|
||||
event_type=model.event_type,
|
||||
category=model.category,
|
||||
content=_deserialize_content(model.content, raw_metadata),
|
||||
metadata=metadata,
|
||||
seq=model.seq,
|
||||
created_at=model.created_at,
|
||||
)
|
||||
|
||||
|
||||
class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
|
||||
if not events:
|
||||
return []
|
||||
|
||||
thread_ids = {event.thread_id for event in events}
|
||||
seq_by_thread: dict[str, int] = {}
|
||||
for thread_id in thread_ids:
|
||||
max_seq = await self._session.scalar(
|
||||
select(func.max(RunEventModel.seq))
|
||||
.where(RunEventModel.thread_id == thread_id)
|
||||
.with_for_update()
|
||||
)
|
||||
seq_by_thread[thread_id] = max_seq or 0
|
||||
|
||||
rows: list[RunEventModel] = []
|
||||
|
||||
for event in events:
|
||||
seq_by_thread[event.thread_id] += 1
|
||||
content, metadata = _serialize_content(event.content, dict(event.metadata))
|
||||
row = RunEventModel(
|
||||
thread_id=event.thread_id,
|
||||
run_id=event.run_id,
|
||||
seq=seq_by_thread[event.thread_id],
|
||||
event_type=event.event_type,
|
||||
category=event.category,
|
||||
content=content,
|
||||
meta=metadata,
|
||||
)
|
||||
if event.created_at is not None:
|
||||
row.created_at = event.created_at
|
||||
self._session.add(row)
|
||||
rows.append(row)
|
||||
|
||||
await self._session.flush()
|
||||
return [_to_run_event(row) for row in rows]
|
||||
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[RunEvent]:
|
||||
stmt = select(RunEventModel).where(
|
||||
RunEventModel.thread_id == thread_id,
|
||||
RunEventModel.category == "message",
|
||||
)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_run_event(row) for row in result.scalars().all()]
|
||||
|
||||
stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
event_types: list[str] | None = None,
|
||||
limit: int = 500,
|
||||
) -> list[RunEvent]:
|
||||
stmt = select(RunEventModel).where(
|
||||
RunEventModel.thread_id == thread_id,
|
||||
RunEventModel.run_id == run_id,
|
||||
)
|
||||
if event_types is not None:
|
||||
stmt = stmt.where(RunEventModel.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventModel.seq.asc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_run_event(row) for row in result.scalars().all()]
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
) -> list[RunEvent]:
|
||||
stmt = (
|
||||
select(RunEventModel)
|
||||
.where(
|
||||
RunEventModel.thread_id == thread_id,
|
||||
RunEventModel.run_id == run_id,
|
||||
RunEventModel.category == "message",
|
||||
)
|
||||
)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_run_event(row) for row in result.scalars().all()]
|
||||
|
||||
stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
|
||||
|
||||
async def count_messages(self, thread_id: str) -> int:
|
||||
count = await self._session.scalar(
|
||||
select(func.count())
|
||||
.select_from(RunEventModel)
|
||||
.where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
|
||||
)
|
||||
return int(count or 0)
|
||||
|
||||
async def delete_by_thread(self, thread_id: str) -> int:
|
||||
count = await self._session.scalar(
|
||||
select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id)
|
||||
)
|
||||
await self._session.execute(delete(RunEventModel).where(RunEventModel.thread_id == thread_id))
|
||||
return int(count or 0)
|
||||
|
||||
async def delete_by_run(self, thread_id: str, run_id: str) -> int:
|
||||
count = await self._session.scalar(
|
||||
select(func.count())
|
||||
.select_from(RunEventModel)
|
||||
.where(RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id)
|
||||
)
|
||||
await self._session.execute(
|
||||
delete(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id)
|
||||
)
|
||||
return int(count or 0)
|
||||
@@ -0,0 +1,97 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.thread_meta import ThreadMeta, ThreadMetaCreate, ThreadMetaRepositoryProtocol
|
||||
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
|
||||
|
||||
|
||||
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
|
||||
return ThreadMeta(
|
||||
thread_id=m.thread_id,
|
||||
assistant_id=m.assistant_id,
|
||||
user_id=m.user_id,
|
||||
display_name=m.display_name,
|
||||
status=m.status,
|
||||
metadata=dict(m.meta or {}),
|
||||
created_time=m.created_time,
|
||||
updated_time=m.updated_time,
|
||||
)
|
||||
|
||||
|
||||
class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
|
||||
model = ThreadMetaModel(
|
||||
thread_id=data.thread_id,
|
||||
assistant_id=data.assistant_id,
|
||||
user_id=data.user_id,
|
||||
display_name=data.display_name,
|
||||
status=data.status,
|
||||
meta=dict(data.metadata),
|
||||
)
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
return _to_thread_meta(model)
|
||||
|
||||
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
|
||||
result = await self._session.execute(select(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_thread_meta(model) if model else None
|
||||
|
||||
async def update_thread_meta(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
display_name: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
values: dict = {}
|
||||
if display_name is not None:
|
||||
values["display_name"] = display_name
|
||||
if status is not None:
|
||||
values["status"] = status
|
||||
if metadata is not None:
|
||||
values["meta"] = dict(metadata)
|
||||
if not values:
|
||||
return
|
||||
await self._session.execute(
|
||||
update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))
|
||||
|
||||
async def delete_thread(self, thread_id: str) -> None:
|
||||
await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
|
||||
|
||||
async def search_threads(
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None = None,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ThreadMeta]:
|
||||
stmt = select(ThreadMetaModel)
|
||||
|
||||
if status is not None:
|
||||
stmt = stmt.where(ThreadMetaModel.status == status)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(ThreadMetaModel.user_id == user_id)
|
||||
if assistant_id is not None:
|
||||
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
|
||||
if metadata:
|
||||
for key, value in metadata.items():
|
||||
stmt = stmt.where(ThreadMetaModel.meta[key].as_string() == str(value))
|
||||
|
||||
stmt = stmt.order_by(ThreadMetaModel.created_time.desc())
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_thread_meta(m) for m in result.scalars().all()]
|
||||
Reference in New Issue
Block a user