mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-22 16:06:50 +00:00
fix(runtime): avoid postgres aggregate row lock (#2962)
This commit is contained in:
@@ -11,7 +11,7 @@ import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import delete, func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
@@ -86,6 +86,28 @@ class DbRunEventStore(RunEventStore):
|
||||
user = get_current_user()
|
||||
return str(user.id) if user is not None else None
|
||||
|
||||
@staticmethod
|
||||
async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None:
|
||||
"""Return the current max seq while serializing writers per thread.
|
||||
|
||||
PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate
|
||||
results are not lockable rows. As a release-safe workaround, take a
|
||||
transaction-level advisory lock keyed by thread_id before reading the
|
||||
aggregate. Other dialects keep the existing row-locking statement.
|
||||
"""
|
||||
stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)
|
||||
bind = session.get_bind()
|
||||
dialect_name = bind.dialect.name if bind is not None else ""
|
||||
|
||||
if dialect_name == "postgresql":
|
||||
await session.execute(
|
||||
text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"),
|
||||
{"thread_id": thread_id},
|
||||
)
|
||||
return await session.scalar(stmt)
|
||||
|
||||
return await session.scalar(stmt.with_for_update())
|
||||
|
||||
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
|
||||
"""Write a single event — low-frequency path only.
|
||||
|
||||
@@ -100,10 +122,7 @@ class DbRunEventStore(RunEventStore):
|
||||
user_id = self._user_id_from_context()
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
max_seq = await self._max_seq_for_thread(session, thread_id)
|
||||
seq = (max_seq or 0) + 1
|
||||
row = RunEventRow(
|
||||
thread_id=thread_id,
|
||||
@@ -126,10 +145,8 @@ class DbRunEventStore(RunEventStore):
|
||||
async with self._sf() as session:
|
||||
async with session.begin():
|
||||
# Get max seq for the thread (assume all events in batch belong to same thread).
|
||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||
thread_id = events[0]["thread_id"]
|
||||
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||
max_seq = await self._max_seq_for_thread(session, thread_id)
|
||||
seq = max_seq or 0
|
||||
rows = []
|
||||
for e in events:
|
||||
|
||||
Reference in New Issue
Block a user