mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 08:25:57 +00:00
fix(persistence): address review feedback on PR #1851
- Fix naive datetime.now() → datetime.now(UTC) in all ORM models - Fix seq race condition in DbRunEventStore.put() with FOR UPDATE and UNIQUE(thread_id, seq) constraint - Encapsulate _store access in RunManager.update_run_completion() - Deduplicate _store.put() logic in RunManager via _persist_to_store() - Add update_run_completion to RunStore ABC + MemoryRunStore - Wire follow_up_to_run_id through the full create path - Add error recovery to RunJournal._flush_sync() lost-event scenario - Add migration note for search_threads breaking change - Fix test_checkpointer_none_fix mock to set database=None Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -317,7 +317,16 @@ async def create_thread(body: ThreadCreateRequest, request: Request) -> ThreadRe
|
|||||||
|
|
||||||
@router.post("/search", response_model=list[ThreadResponse])
|
@router.post("/search", response_model=list[ThreadResponse])
|
||||||
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
|
async def search_threads(body: ThreadSearchRequest, request: Request) -> list[ThreadResponse]:
|
||||||
"""Search and list threads from the threads_meta table."""
|
"""Search and list threads from the threads_meta table.
|
||||||
|
|
||||||
|
NOTE: Migration from pre-persistence-layer deployments:
|
||||||
|
Threads created via LangGraph Server before this change are NOT
|
||||||
|
automatically indexed in threads_meta. They will not appear in
|
||||||
|
search results until a new run is created on them (which triggers
|
||||||
|
thread_meta upsert in services.py). For bulk migration, run:
|
||||||
|
python -m deerflow.persistence.migrate_threads_from_checkpointer
|
||||||
|
(migration script TBD in a follow-up PR)
|
||||||
|
"""
|
||||||
from app.gateway.deps import get_thread_meta_repo
|
from app.gateway.deps import get_thread_meta_repo
|
||||||
|
|
||||||
repo = get_thread_meta_repo(request)
|
repo = get_thread_meta_repo(request)
|
||||||
|
|||||||
@@ -266,6 +266,17 @@ async def start_run(
|
|||||||
|
|
||||||
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
disconnect = DisconnectMode.cancel if body.on_disconnect == "cancel" else DisconnectMode.continue_
|
||||||
|
|
||||||
|
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
|
||||||
|
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
|
||||||
|
if follow_up_to_run_id is None:
|
||||||
|
run_store = get_run_store(request)
|
||||||
|
try:
|
||||||
|
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
|
||||||
|
if recent_runs and recent_runs[0].get("status") == "success":
|
||||||
|
follow_up_to_run_id = recent_runs[0]["run_id"]
|
||||||
|
except Exception:
|
||||||
|
pass # Don't block run creation
|
||||||
|
|
||||||
try:
|
try:
|
||||||
record = await run_mgr.create_or_reject(
|
record = await run_mgr.create_or_reject(
|
||||||
thread_id,
|
thread_id,
|
||||||
@@ -274,6 +285,7 @@ async def start_run(
|
|||||||
metadata=body.metadata or {},
|
metadata=body.metadata or {},
|
||||||
kwargs={"input": body.input, "config": body.config},
|
kwargs={"input": body.input, "config": body.config},
|
||||||
multitask_strategy=body.multitask_strategy,
|
multitask_strategy=body.multitask_strategy,
|
||||||
|
follow_up_to_run_id=follow_up_to_run_id,
|
||||||
)
|
)
|
||||||
except ConflictError as exc:
|
except ConflictError as exc:
|
||||||
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
raise HTTPException(status_code=409, detail=str(exc)) from exc
|
||||||
@@ -302,17 +314,6 @@ async def start_run(
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", thread_id)
|
logger.warning("Failed to upsert thread_meta for %s (non-fatal)", thread_id)
|
||||||
|
|
||||||
# Resolve follow_up_to_run_id: explicit from request, or auto-detect from latest successful run
|
|
||||||
follow_up_to_run_id = getattr(body, "follow_up_to_run_id", None)
|
|
||||||
if follow_up_to_run_id is None:
|
|
||||||
run_store = get_run_store(request)
|
|
||||||
try:
|
|
||||||
recent_runs = await run_store.list_by_thread(thread_id, limit=1)
|
|
||||||
if recent_runs and recent_runs[0].get("status") == "success":
|
|
||||||
follow_up_to_run_id = recent_runs[0]["run_id"]
|
|
||||||
except Exception:
|
|
||||||
pass # Don't block run creation
|
|
||||||
|
|
||||||
agent_factory = resolve_agent_factory(body.assistant_id)
|
agent_factory = resolve_agent_factory(body.assistant_id)
|
||||||
graph_input = normalize_input(body.input)
|
graph_input = normalize_input(body.input)
|
||||||
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
config = build_run_config(thread_id, body.config, body.metadata, assistant_id=body.assistant_id)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from sqlalchemy import DateTime, String, Text
|
from sqlalchemy import DateTime, String, Text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
@@ -27,4 +27,4 @@ class FeedbackRow(Base):
|
|||||||
comment: Mapped[str | None] = mapped_column(Text)
|
comment: Mapped[str | None] = mapped_column(Text)
|
||||||
# Optional text feedback from the user
|
# Optional text feedback from the user
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from sqlalchemy import JSON, DateTime, Index, String, Text
|
from sqlalchemy import JSON, DateTime, Index, String, Text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
@@ -43,7 +43,7 @@ class RunRow(Base):
|
|||||||
# Follow-up association
|
# Follow-up association
|
||||||
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
|
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now())
|
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
|
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from sqlalchemy import JSON, DateTime, Index, String, Text
|
from sqlalchemy import JSON, DateTime, Index, String, Text, UniqueConstraint
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from deerflow.persistence.base import Base
|
from deerflow.persistence.base import Base
|
||||||
@@ -22,9 +22,10 @@ class RunEventRow(Base):
|
|||||||
content: Mapped[str] = mapped_column(Text, default="")
|
content: Mapped[str] = mapped_column(Text, default="")
|
||||||
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
|
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||||
seq: Mapped[int] = mapped_column(nullable=False)
|
seq: Mapped[int] = mapped_column(nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
|
UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
|
||||||
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
|
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
|
||||||
Index("ix_events_run", "thread_id", "run_id", "seq"),
|
Index("ix_events_run", "thread_id", "run_id", "seq"),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import UTC, datetime
|
||||||
|
|
||||||
from sqlalchemy import JSON, DateTime, String
|
from sqlalchemy import JSON, DateTime, String
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
@@ -19,5 +19,5 @@ class ThreadMetaRow(Base):
|
|||||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now())
|
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
||||||
|
|||||||
@@ -78,6 +78,7 @@ class RunRepository(RunStore):
|
|||||||
kwargs=None,
|
kwargs=None,
|
||||||
error=None,
|
error=None,
|
||||||
created_at=None,
|
created_at=None,
|
||||||
|
follow_up_to_run_id=None,
|
||||||
):
|
):
|
||||||
now = datetime.now(UTC)
|
now = datetime.now(UTC)
|
||||||
row = RunRow(
|
row = RunRow(
|
||||||
@@ -90,6 +91,7 @@ class RunRepository(RunStore):
|
|||||||
metadata_json=self._safe_json(metadata) or {},
|
metadata_json=self._safe_json(metadata) or {},
|
||||||
kwargs_json=self._safe_json(kwargs) or {},
|
kwargs_json=self._safe_json(kwargs) or {},
|
||||||
error=error,
|
error=error,
|
||||||
|
follow_up_to_run_id=follow_up_to_run_id,
|
||||||
created_at=datetime.fromisoformat(created_at) if created_at else now,
|
created_at=datetime.fromisoformat(created_at) if created_at else now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -54,58 +54,68 @@ class DbRunEventStore(RunEventStore):
|
|||||||
else:
|
else:
|
||||||
db_content = content
|
db_content = content
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
|
async with session.begin():
|
||||||
seq = (max_seq or 0) + 1
|
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||||
row = RunEventRow(
|
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||||
thread_id=thread_id,
|
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||||
run_id=run_id,
|
max_seq = await session.scalar(
|
||||||
event_type=event_type,
|
select(func.max(RunEventRow.seq))
|
||||||
category=category,
|
.where(RunEventRow.thread_id == thread_id)
|
||||||
content=db_content,
|
.with_for_update()
|
||||||
event_metadata=metadata,
|
)
|
||||||
seq=seq,
|
seq = (max_seq or 0) + 1
|
||||||
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
|
row = RunEventRow(
|
||||||
)
|
thread_id=thread_id,
|
||||||
session.add(row)
|
run_id=run_id,
|
||||||
await session.commit()
|
event_type=event_type,
|
||||||
await session.refresh(row)
|
category=category,
|
||||||
|
content=db_content,
|
||||||
|
event_metadata=metadata,
|
||||||
|
seq=seq,
|
||||||
|
created_at=datetime.fromisoformat(created_at) if created_at else datetime.now(UTC),
|
||||||
|
)
|
||||||
|
session.add(row)
|
||||||
return self._row_to_dict(row)
|
return self._row_to_dict(row)
|
||||||
|
|
||||||
async def put_batch(self, events):
|
async def put_batch(self, events):
|
||||||
if not events:
|
if not events:
|
||||||
return []
|
return []
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
# Get max seq for the thread (assume all events in batch belong to same thread)
|
async with session.begin():
|
||||||
thread_id = events[0]["thread_id"]
|
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id))
|
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||||
seq = max_seq or 0
|
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||||
rows = []
|
thread_id = events[0]["thread_id"]
|
||||||
for e in events:
|
max_seq = await session.scalar(
|
||||||
seq += 1
|
select(func.max(RunEventRow.seq))
|
||||||
content = e.get("content", "")
|
.where(RunEventRow.thread_id == thread_id)
|
||||||
category = e.get("category", "trace")
|
.with_for_update()
|
||||||
metadata = e.get("metadata")
|
|
||||||
content, metadata = self._truncate_trace(category, content, metadata)
|
|
||||||
if isinstance(content, dict):
|
|
||||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
|
||||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
|
||||||
else:
|
|
||||||
db_content = content
|
|
||||||
row = RunEventRow(
|
|
||||||
thread_id=e["thread_id"],
|
|
||||||
run_id=e["run_id"],
|
|
||||||
event_type=e["event_type"],
|
|
||||||
category=category,
|
|
||||||
content=db_content,
|
|
||||||
event_metadata=metadata,
|
|
||||||
seq=seq,
|
|
||||||
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
|
|
||||||
)
|
)
|
||||||
session.add(row)
|
seq = max_seq or 0
|
||||||
rows.append(row)
|
rows = []
|
||||||
await session.commit()
|
for e in events:
|
||||||
for row in rows:
|
seq += 1
|
||||||
await session.refresh(row)
|
content = e.get("content", "")
|
||||||
|
category = e.get("category", "trace")
|
||||||
|
metadata = e.get("metadata")
|
||||||
|
content, metadata = self._truncate_trace(category, content, metadata)
|
||||||
|
if isinstance(content, dict):
|
||||||
|
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||||
|
metadata = {**(metadata or {}), "content_is_dict": True}
|
||||||
|
else:
|
||||||
|
db_content = content
|
||||||
|
row = RunEventRow(
|
||||||
|
thread_id=e["thread_id"],
|
||||||
|
run_id=e["run_id"],
|
||||||
|
event_type=e["event_type"],
|
||||||
|
category=category,
|
||||||
|
content=db_content,
|
||||||
|
event_metadata=metadata,
|
||||||
|
seq=seq,
|
||||||
|
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
|
||||||
|
)
|
||||||
|
session.add(row)
|
||||||
|
rows.append(row)
|
||||||
return [self._row_to_dict(r) for r in rows]
|
return [self._row_to_dict(r) for r in rows]
|
||||||
|
|
||||||
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
async def list_messages(self, thread_id, *, limit=50, before_seq=None, after_seq=None):
|
||||||
|
|||||||
@@ -386,13 +386,27 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
return
|
return
|
||||||
batch = self._buffer.copy()
|
batch = self._buffer.copy()
|
||||||
self._buffer.clear()
|
self._buffer.clear()
|
||||||
loop.create_task(self._flush_async(batch))
|
task = loop.create_task(self._flush_async(batch))
|
||||||
|
task.add_done_callback(self._on_flush_done)
|
||||||
|
|
||||||
async def _flush_async(self, batch: list[dict]) -> None:
|
async def _flush_async(self, batch: list[dict]) -> None:
|
||||||
try:
|
try:
|
||||||
await self._store.put_batch(batch)
|
await self._store.put_batch(batch)
|
||||||
except Exception:
|
except Exception:
|
||||||
logger.warning("RunJournal: failed to flush %d events", len(batch), exc_info=True)
|
logger.warning(
|
||||||
|
"Failed to flush %d events for run %s — returning to buffer",
|
||||||
|
len(batch), self.run_id, exc_info=True,
|
||||||
|
)
|
||||||
|
# Return failed events to buffer for retry on next flush
|
||||||
|
self._buffer = batch + self._buffer
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _on_flush_done(task: asyncio.Task) -> None:
|
||||||
|
if task.cancelled():
|
||||||
|
return
|
||||||
|
exc = task.exception()
|
||||||
|
if exc:
|
||||||
|
logger.warning("Journal flush task failed: %s", exc)
|
||||||
|
|
||||||
def _identify_caller(self, kwargs: dict) -> str:
|
def _identify_caller(self, kwargs: dict) -> str:
|
||||||
for tag in kwargs.get("tags") or []:
|
for tag in kwargs.get("tags") or []:
|
||||||
|
|||||||
@@ -54,6 +54,33 @@ class RunManager:
|
|||||||
self._lock = asyncio.Lock()
|
self._lock = asyncio.Lock()
|
||||||
self._store = store
|
self._store = store
|
||||||
|
|
||||||
|
async def _persist_to_store(self, record: RunRecord, *, follow_up_to_run_id: str | None = None) -> None:
|
||||||
|
"""Best-effort persist run record to backing store."""
|
||||||
|
if self._store is None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
await self._store.put(
|
||||||
|
record.run_id,
|
||||||
|
thread_id=record.thread_id,
|
||||||
|
assistant_id=record.assistant_id,
|
||||||
|
status=record.status.value,
|
||||||
|
multitask_strategy=record.multitask_strategy,
|
||||||
|
metadata=record.metadata or {},
|
||||||
|
kwargs=record.kwargs or {},
|
||||||
|
created_at=record.created_at,
|
||||||
|
follow_up_to_run_id=follow_up_to_run_id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
||||||
|
|
||||||
|
async def update_run_completion(self, run_id: str, **kwargs) -> None:
|
||||||
|
"""Persist token usage and completion data to the backing store."""
|
||||||
|
if self._store is not None:
|
||||||
|
try:
|
||||||
|
await self._store.update_run_completion(run_id, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
||||||
|
|
||||||
async def create(
|
async def create(
|
||||||
self,
|
self,
|
||||||
thread_id: str,
|
thread_id: str,
|
||||||
@@ -63,6 +90,7 @@ class RunManager:
|
|||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
|
follow_up_to_run_id: str | None = None,
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Create a new pending run and register it."""
|
"""Create a new pending run and register it."""
|
||||||
run_id = str(uuid.uuid4())
|
run_id = str(uuid.uuid4())
|
||||||
@@ -81,20 +109,7 @@ class RunManager:
|
|||||||
)
|
)
|
||||||
async with self._lock:
|
async with self._lock:
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
if self._store is not None:
|
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||||
try:
|
|
||||||
await self._store.put(
|
|
||||||
run_id,
|
|
||||||
thread_id=thread_id,
|
|
||||||
assistant_id=assistant_id,
|
|
||||||
status=RunStatus.pending.value,
|
|
||||||
multitask_strategy=multitask_strategy,
|
|
||||||
metadata=metadata or {},
|
|
||||||
kwargs=kwargs or {},
|
|
||||||
created_at=now,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
|
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
@@ -161,6 +176,7 @@ class RunManager:
|
|||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
kwargs: dict | None = None,
|
kwargs: dict | None = None,
|
||||||
multitask_strategy: str = "reject",
|
multitask_strategy: str = "reject",
|
||||||
|
follow_up_to_run_id: str | None = None,
|
||||||
) -> RunRecord:
|
) -> RunRecord:
|
||||||
"""Atomically check for inflight runs and create a new one.
|
"""Atomically check for inflight runs and create a new one.
|
||||||
|
|
||||||
@@ -214,21 +230,7 @@ class RunManager:
|
|||||||
)
|
)
|
||||||
self._runs[run_id] = record
|
self._runs[run_id] = record
|
||||||
|
|
||||||
if self._store is not None:
|
await self._persist_to_store(record, follow_up_to_run_id=follow_up_to_run_id)
|
||||||
try:
|
|
||||||
await self._store.put(
|
|
||||||
run_id,
|
|
||||||
thread_id=thread_id,
|
|
||||||
assistant_id=assistant_id,
|
|
||||||
status=RunStatus.pending.value,
|
|
||||||
multitask_strategy=multitask_strategy,
|
|
||||||
metadata=metadata or {},
|
|
||||||
kwargs=kwargs or {},
|
|
||||||
created_at=now,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to persist run %s to store", run_id, exc_info=True)
|
|
||||||
|
|
||||||
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
logger.info("Run created: run_id=%s thread_id=%s", run_id, thread_id)
|
||||||
return record
|
return record
|
||||||
|
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ class RunStore(abc.ABC):
|
|||||||
kwargs: dict[str, Any] | None = None,
|
kwargs: dict[str, Any] | None = None,
|
||||||
error: str | None = None,
|
error: str | None = None,
|
||||||
created_at: str | None = None,
|
created_at: str | None = None,
|
||||||
|
follow_up_to_run_id: str | None = None,
|
||||||
) -> None: ...
|
) -> None: ...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
@@ -55,5 +56,24 @@ class RunStore(abc.ABC):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def delete(self, run_id: str) -> None: ...
|
async def delete(self, run_id: str) -> None: ...
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def update_run_completion(
|
||||||
|
self,
|
||||||
|
run_id: str,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
total_input_tokens: int = 0,
|
||||||
|
total_output_tokens: int = 0,
|
||||||
|
total_tokens: int = 0,
|
||||||
|
llm_call_count: int = 0,
|
||||||
|
lead_agent_tokens: int = 0,
|
||||||
|
subagent_tokens: int = 0,
|
||||||
|
middleware_tokens: int = 0,
|
||||||
|
message_count: int = 0,
|
||||||
|
last_ai_message: str | None = None,
|
||||||
|
first_human_message: str | None = None,
|
||||||
|
error: str | None = None,
|
||||||
|
) -> None: ...
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: ...
|
async def list_pending(self, *, before: str | None = None) -> list[dict[str, Any]]: ...
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ class MemoryRunStore(RunStore):
|
|||||||
kwargs=None,
|
kwargs=None,
|
||||||
error=None,
|
error=None,
|
||||||
created_at=None,
|
created_at=None,
|
||||||
|
follow_up_to_run_id=None,
|
||||||
):
|
):
|
||||||
now = datetime.now(UTC).isoformat()
|
now = datetime.now(UTC).isoformat()
|
||||||
self._runs[run_id] = {
|
self._runs[run_id] = {
|
||||||
@@ -40,6 +41,7 @@ class MemoryRunStore(RunStore):
|
|||||||
"metadata": metadata or {},
|
"metadata": metadata or {},
|
||||||
"kwargs": kwargs or {},
|
"kwargs": kwargs or {},
|
||||||
"error": error,
|
"error": error,
|
||||||
|
"follow_up_to_run_id": follow_up_to_run_id,
|
||||||
"created_at": created_at or now,
|
"created_at": created_at or now,
|
||||||
"updated_at": now,
|
"updated_at": now,
|
||||||
}
|
}
|
||||||
@@ -62,6 +64,14 @@ class MemoryRunStore(RunStore):
|
|||||||
async def delete(self, run_id):
|
async def delete(self, run_id):
|
||||||
self._runs.pop(run_id, None)
|
self._runs.pop(run_id, None)
|
||||||
|
|
||||||
|
async def update_run_completion(self, run_id, *, status, **kwargs):
|
||||||
|
if run_id in self._runs:
|
||||||
|
self._runs[run_id]["status"] = status
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if value is not None:
|
||||||
|
self._runs[run_id][key] = value
|
||||||
|
self._runs[run_id]["updated_at"] = datetime.now(UTC).isoformat()
|
||||||
|
|
||||||
async def list_pending(self, *, before=None):
|
async def list_pending(self, *, before=None):
|
||||||
now = before or datetime.now(UTC).isoformat()
|
now = before or datetime.now(UTC).isoformat()
|
||||||
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
|
results = [r for r in self._runs.values() if r["status"] == "pending" and r["created_at"] <= now]
|
||||||
|
|||||||
@@ -257,16 +257,8 @@ async def run_agent(
|
|||||||
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
logger.warning("Failed to flush journal for run %s", run_id, exc_info=True)
|
||||||
|
|
||||||
# Persist token usage + convenience fields to RunStore
|
# Persist token usage + convenience fields to RunStore
|
||||||
if run_manager._store is not None:
|
completion = journal.get_completion_data()
|
||||||
try:
|
await run_manager.update_run_completion(run_id, status=record.status.value, **completion)
|
||||||
completion = journal.get_completion_data()
|
|
||||||
await run_manager._store.update_run_completion(
|
|
||||||
run_id,
|
|
||||||
status=record.status.value,
|
|
||||||
**completion,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
logger.warning("Failed to persist run completion for %s", run_id, exc_info=True)
|
|
||||||
|
|
||||||
# Sync title from checkpoint to threads_meta.display_name
|
# Sync title from checkpoint to threads_meta.display_name
|
||||||
if thread_meta_repo is not None and checkpointer is not None:
|
if thread_meta_repo is not None and checkpointer is not None:
|
||||||
|
|||||||
@@ -14,9 +14,10 @@ class TestCheckpointerNoneFix:
|
|||||||
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
"""make_checkpointer should return InMemorySaver when config.checkpointer is None."""
|
||||||
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
from deerflow.agents.checkpointer.async_provider import make_checkpointer
|
||||||
|
|
||||||
# Mock get_app_config to return a config with checkpointer=None
|
# Mock get_app_config to return a config with checkpointer=None and database=None
|
||||||
mock_config = MagicMock()
|
mock_config = MagicMock()
|
||||||
mock_config.checkpointer = None
|
mock_config.checkpointer = None
|
||||||
|
mock_config.database = None
|
||||||
|
|
||||||
with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
with patch("deerflow.agents.checkpointer.async_provider.get_app_config", return_value=mock_config):
|
||||||
async with make_checkpointer() as checkpointer:
|
async with make_checkpointer() as checkpointer:
|
||||||
|
|||||||
Reference in New Issue
Block a user