mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
style(storage): format storage package
This commit is contained in:
@@ -67,9 +67,7 @@ class DbFeedbackRepository(FeedbackRepositoryProtocol):
|
||||
return _to_feedback(model)
|
||||
|
||||
async def get_feedback(self, feedback_id: str) -> Feedback | None:
|
||||
result = await self._session.execute(
|
||||
select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)
|
||||
)
|
||||
result = await self._session.execute(select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_feedback(model) if model else None
|
||||
|
||||
@@ -112,9 +110,7 @@ class DbFeedbackRepository(FeedbackRepositoryProtocol):
|
||||
existing = await self.get_feedback(feedback_id)
|
||||
if existing is None:
|
||||
return False
|
||||
await self._session.execute(
|
||||
delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id)
|
||||
)
|
||||
await self._session.execute(delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
|
||||
return True
|
||||
|
||||
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
|
||||
|
||||
@@ -64,9 +64,7 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
return _to_run(model)
|
||||
|
||||
async def get_run(self, run_id: str) -> Run | None:
|
||||
result = await self._session.execute(
|
||||
select(RunModel).where(RunModel.run_id == run_id)
|
||||
)
|
||||
result = await self._session.execute(select(RunModel).where(RunModel.run_id == run_id))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_run(model) if model else None
|
||||
|
||||
@@ -85,15 +83,11 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_run(m) for m in result.scalars().all()]
|
||||
|
||||
async def update_run_status(
|
||||
self, run_id: str, status: str, *, error: str | None = None
|
||||
) -> None:
|
||||
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
|
||||
values: dict = {"status": status}
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
await self._session.execute(
|
||||
update(RunModel).where(RunModel.run_id == run_id).values(**values)
|
||||
)
|
||||
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
|
||||
|
||||
async def delete_run(self, run_id: str) -> None:
|
||||
await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))
|
||||
@@ -106,11 +100,7 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
else:
|
||||
before_dt = datetime.fromisoformat(before)
|
||||
|
||||
result = await self._session.execute(
|
||||
select(RunModel)
|
||||
.where(RunModel.status == "pending", RunModel.created_time <= before_dt)
|
||||
.order_by(RunModel.created_time.asc())
|
||||
)
|
||||
result = await self._session.execute(select(RunModel).where(RunModel.status == "pending", RunModel.created_time <= before_dt).order_by(RunModel.created_time.asc()))
|
||||
return [_to_run(m) for m in result.scalars().all()]
|
||||
|
||||
async def update_run_completion(
|
||||
@@ -147,9 +137,7 @@ class DbRunRepository(RunRepositoryProtocol):
|
||||
values["last_ai_message"] = last_ai_message[:2000]
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
await self._session.execute(
|
||||
update(RunModel).where(RunModel.run_id == run_id).values(**values)
|
||||
)
|
||||
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
completed = RunModel.status.in_(("success", "error"))
|
||||
|
||||
@@ -158,13 +158,10 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
after_seq: int | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]:
|
||||
stmt = (
|
||||
select(RunEventModel)
|
||||
.where(
|
||||
RunEventModel.thread_id == thread_id,
|
||||
RunEventModel.run_id == run_id,
|
||||
RunEventModel.category == "message",
|
||||
)
|
||||
stmt = select(RunEventModel).where(
|
||||
RunEventModel.thread_id == thread_id,
|
||||
RunEventModel.run_id == run_id,
|
||||
RunEventModel.category == "message",
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(RunEventModel.user_id == user_id)
|
||||
@@ -182,11 +179,7 @@ class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
|
||||
|
||||
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
|
||||
stmt = (
|
||||
select(func.count())
|
||||
.select_from(RunEventModel)
|
||||
.where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
|
||||
)
|
||||
stmt = select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(RunEventModel.user_id == user_id)
|
||||
count = await self._session.scalar(stmt)
|
||||
|
||||
@@ -55,12 +55,12 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
|
||||
return _to_thread_meta(model) if model else None
|
||||
|
||||
async def update_thread_meta(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
display_name: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
display_name: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
values: dict = {}
|
||||
if display_name is not None:
|
||||
@@ -71,21 +71,20 @@ class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
|
||||
values["meta"] = dict(metadata)
|
||||
if not values:
|
||||
return
|
||||
await self._session.execute(
|
||||
update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))
|
||||
await self._session.execute(update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))
|
||||
|
||||
async def delete_thread(self, thread_id: str) -> None:
|
||||
await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
|
||||
|
||||
async def search_threads(
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None = None,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None = None,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ThreadMeta]:
|
||||
stmt = select(ThreadMetaModel)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user