Files
deer-flow/backend/packages/harness/deerflow/runtime/events/store/db.py
T
john lee cbf8b194e8 fix(runtime): harden JSONL async I/O and DB put_batch thread validation (#3084)
* fix(runtime): harden JSONL async I/O and DB put_batch thread validation (#2816)

- JsonlRunEventStore: offload all file I/O to asyncio.to_thread() so the
  event loop is never blocked; add per-thread asyncio.Lock to serialise
  concurrent puts and prevent interleaved JSONL lines
- Split _ensure_seq_loaded into a sync _compute_max_seq (runs in thread)
  and an async wrapper; seq counter is recovered from disk on fresh store init
- DbRunEventStore.put_batch: raise ValueError when events span multiple
  thread_ids (previously silently assumed same thread)
- Add test_jsonl_event_store_async_io.py: 12 tests covering lock reuse,
  concurrent seq monotonicity, disk recovery, and mixed-thread batch rejection

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix: address Copilot review comments

- delete_by_thread: pop _write_locks after releasing the lock to prevent
  unbounded growth when threads are repeatedly created and deleted
- tests: add regression guard asserting asyncio.to_thread is called for
  _write_record in put(); assert _write_locks entry removed on delete

* fix(lint): move patch import to local scope to fix ruff I001

* fix(lint): apply ruff check+format fixes to test file

* fix(runtime): address review feedback for JSONL async I/O hardening (#2816)

Use setdefault for atomic lock init in _get_write_lock; pop _write_locks
inside the held lock scope in delete_by_thread; update test docstring
and assert lock entry also cleared on delete.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: rayhpeng <rayhpeng@gmail.com>
2026-05-29 09:27:53 +08:00

316 lines
14 KiB
Python

"""SQLAlchemy-backed RunEventStore implementation.
Persists events to the ``run_events`` table. Trace content is truncated
at ``max_trace_content`` bytes to avoid bloating the database.
"""
from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import delete, func, select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow
from deerflow.runtime.events.store.base import RunEventStore
from deerflow.runtime.user_context import AUTO, _AutoSentinel, get_current_user, resolve_user_id
from deerflow.utils.time import coerce_iso
# Module-level logger; ``DbRunEventStore._row_to_dict`` uses it to report
# stored content that is flagged as JSON but fails to decode.
logger = logging.getLogger(__name__)
class DbRunEventStore(RunEventStore):
    """``RunEventStore`` that persists events to the ``run_events`` table.

    Every operation opens its own ``AsyncSession`` from the factory given
    to the constructor.
    """

    def __init__(self, session_factory: async_sessionmaker[AsyncSession], *, max_trace_content: int = 10240):
        """Store the session factory and the trace-size limit.

        Args:
            session_factory: Factory called to open one ``AsyncSession``
                per store operation.
            max_trace_content: Byte limit applied to ``trace`` content
                before it is written.
        """
        self._max_trace_content = max_trace_content
        self._sf = session_factory
@staticmethod
def _row_to_dict(row: RunEventRow) -> dict:
d = row.to_dict()
d["metadata"] = d.pop("event_metadata", {})
val = d.get("created_at")
if isinstance(val, datetime):
# SQLite drops tzinfo on read despite ``DateTime(timezone=True)``;
# ``coerce_iso`` normalizes naive datetimes as UTC.
d["created_at"] = coerce_iso(val)
d.pop("id", None)
# Restore structured content that was JSON-serialized on write.
raw = d.get("content", "")
metadata = d.get("metadata", {})
if isinstance(raw, str) and (metadata.get("content_is_json") or metadata.get("content_is_dict")):
try:
d["content"] = json.loads(raw)
except (json.JSONDecodeError, ValueError):
# Content looked like JSON but failed to parse;
# keep the raw string as-is.
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
return d
def _truncate_trace(self, category: str, content: Any, metadata: dict | None) -> tuple[Any, dict]:
if category == "trace":
text = content if isinstance(content, str) else json.dumps(content, default=str, ensure_ascii=False)
encoded = text.encode("utf-8")
if len(encoded) > self._max_trace_content:
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
content = encoded[: self._max_trace_content].decode("utf-8", errors="ignore")
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
return content, metadata or {}
@staticmethod
def _content_to_db(content: Any, metadata: dict | None) -> tuple[str, dict]:
metadata = metadata or {}
if isinstance(content, str):
return content, metadata
db_content = json.dumps(content, default=str, ensure_ascii=False)
metadata = {**metadata, "content_is_json": True}
if isinstance(content, dict):
metadata["content_is_dict"] = True
return db_content, metadata
@staticmethod
def _user_id_from_context() -> str | None:
    """Read the current user's id from the contextvar for write paths.

    This is a soft read: when the contextvar is unset (the expected case
    for background worker writes) the result is ``None``, meaning no
    filter and no stamp. HTTP request writes have the contextvar set by
    the auth middleware and get their user_id stamped automatically.

    ``user.id`` is coerced to ``str`` at this boundary. The auth layer
    types ``User.id`` as ``UUID``, but ``run_events.user_id`` is
    ``VARCHAR(64)`` and aiosqlite cannot bind a raw UUID object to a
    VARCHAR column ("type 'UUID' is not supported") — the INSERT would
    silently roll back and the worker would hang.

    Returns:
        The user id as a string, or ``None`` when no user is set.
    """
    current = get_current_user()
    if current is None:
        return None
    return str(current.id)
@staticmethod
async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None:
    """Return the current max ``seq`` of a thread while serializing writers.

    PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate
    results are not lockable rows. As a workaround, on PostgreSQL a
    transaction-level advisory lock keyed by ``thread_id`` is taken before
    the aggregate is read. Every other dialect keeps the ``FOR UPDATE``
    form of the statement.

    Args:
        session: Session in which the lock and the read are issued.
        thread_id: Thread whose events are inspected.

    Returns:
        The value of ``max(seq)`` for the thread as returned by the
        database.
    """
    max_seq_query = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)
    engine = session.get_bind()
    on_postgres = engine is not None and engine.dialect.name == "postgresql"
    if not on_postgres:
        return await session.scalar(max_seq_query.with_for_update())

    await session.execute(
        text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"),
        {"thread_id": thread_id},
    )
    return await session.scalar(max_seq_query)
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):  # noqa: D401
    """Write a single event — low-frequency path only.

    A dedicated transaction is opened and the per-thread max ``seq`` is
    read under a lock so the new event gets the next monotonic ``seq``.
    High-throughput writers should use :meth:`put_batch`, which takes the
    lock once for the whole batch. Currently the only caller is
    ``worker.run_agent`` for the initial ``human_message`` event (once
    per run).

    Args:
        thread_id: Thread the event belongs to.
        run_id: Run the event belongs to.
        event_type: Event type stored on the row.
        category: Event category; ``trace`` content is size-limited.
        content: Event content, a string or any JSON-renderable value.
        metadata: Optional metadata dict.
        created_at: Optional ISO-format timestamp string; the current UTC
            time is used when it is empty or ``None``.

    Returns:
        The stored event as a dict.
    """
    clipped, meta = self._truncate_trace(category, content, metadata)
    db_content, meta = self._content_to_db(clipped, meta)
    user_id = self._user_id_from_context()

    async with self._sf() as session, session.begin():
        current_max = await self._max_seq_for_thread(session, thread_id)
        if created_at:
            stamp = datetime.fromisoformat(created_at)
        else:
            stamp = datetime.now(UTC)
        row = RunEventRow(
            thread_id=thread_id,
            run_id=run_id,
            user_id=user_id,
            event_type=event_type,
            category=category,
            content=db_content,
            event_metadata=meta,
            seq=(current_max or 0) + 1,
            created_at=stamp,
        )
        session.add(row)
    # NOTE(review): the row is read after the transaction has ended, which
    # presumably relies on the session factory not expiring instances on
    # commit — confirm against the factory configuration.
    return self._row_to_dict(row)
async def put_batch(self, events):
if not events:
return []
thread_ids = {e["thread_id"] for e in events}
if len(thread_ids) > 1:
raise ValueError(f"put_batch requires all events to belong to the same thread; got {thread_ids!r}")
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
# All events belong to the same thread (validated above).
thread_id = events[0]["thread_id"]
max_seq = await self._max_seq_for_thread(session, thread_id)
seq = max_seq or 0
rows = []
for e in events:
seq += 1
content = e.get("content", "")
category = e.get("category", "trace")
metadata = e.get("metadata")
content, metadata = self._truncate_trace(category, content, metadata)
db_content, metadata = self._content_to_db(content, metadata)
row = RunEventRow(
thread_id=e["thread_id"],
run_id=e["run_id"],
user_id=e.get("user_id", user_id),
event_type=e["event_type"],
category=category,
content=db_content,
event_metadata=metadata,
seq=seq,
created_at=datetime.fromisoformat(e["created_at"]) if e.get("created_at") else datetime.now(UTC),
)
session.add(row)
rows.append(row)
return [self._row_to_dict(r) for r in rows]
async def list_messages(
    self,
    thread_id,
    *,
    limit=50,
    before_seq=None,
    after_seq=None,
    user_id: str | None | _AutoSentinel = AUTO,
):
    """Return ``message``-category events of a thread in ascending ``seq`` order.

    With ``after_seq`` set, the first ``limit`` events after that cursor
    are returned (forward pagination). Otherwise the last ``limit`` events
    are returned, restricted to ``seq < before_seq`` when ``before_seq``
    is given.

    Args:
        thread_id: Thread whose messages are listed.
        limit: Maximum number of events returned.
        before_seq: Only events with a smaller ``seq`` are considered.
        after_seq: Only events with a larger ``seq`` are considered.
        user_id: Passed to ``resolve_user_id``; when the resolved value is
            not ``None`` only rows with that ``user_id`` are returned.

    Returns:
        A list of event dicts.
    """
    owner = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages")
    filters = [RunEventRow.thread_id == thread_id, RunEventRow.category == "message"]
    if owner is not None:
        filters.append(RunEventRow.user_id == owner)
    if before_seq is not None:
        filters.append(RunEventRow.seq < before_seq)
    if after_seq is not None:
        filters.append(RunEventRow.seq > after_seq)

    # Forward pagination reads ascending from the cursor; the default and
    # ``before_seq`` cases read the newest rows first and flip them below.
    forward = after_seq is not None
    ordering = RunEventRow.seq.asc() if forward else RunEventRow.seq.desc()
    stmt = select(RunEventRow).where(*filters).order_by(ordering).limit(limit)

    async with self._sf() as session:
        rows = list((await session.execute(stmt)).scalars())
        if not forward:
            rows.reverse()
        return [self._row_to_dict(row) for row in rows]
async def list_events(
    self,
    thread_id,
    run_id,
    *,
    event_types=None,
    limit=500,
    user_id: str | None | _AutoSentinel = AUTO,
):
    """Return the events of one run in ascending ``seq`` order.

    Args:
        thread_id: Thread the run belongs to.
        run_id: Run whose events are listed.
        event_types: Optional collection of event types; when non-empty
            only events of these types are returned.
        limit: Maximum number of events returned.
        user_id: Passed to ``resolve_user_id``; when the resolved value is
            not ``None`` only rows with that ``user_id`` are returned.

    Returns:
        A list of event dicts.
    """
    owner = resolve_user_id(user_id, method_name="DbRunEventStore.list_events")
    filters = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
    if owner is not None:
        filters.append(RunEventRow.user_id == owner)
    if event_types:
        filters.append(RunEventRow.event_type.in_(event_types))
    stmt = select(RunEventRow).where(*filters).order_by(RunEventRow.seq.asc()).limit(limit)

    async with self._sf() as session:
        fetched = await session.execute(stmt)
        return [self._row_to_dict(row) for row in fetched.scalars()]
async def list_messages_by_run(
    self,
    thread_id,
    run_id,
    *,
    limit=50,
    before_seq=None,
    after_seq=None,
    user_id: str | None | _AutoSentinel = AUTO,
):
    """Return ``message``-category events of one run in ascending ``seq`` order.

    With ``after_seq`` set, the first ``limit`` events after that cursor
    are returned. Otherwise the last ``limit`` events are returned,
    restricted to ``seq < before_seq`` when ``before_seq`` is given.

    Args:
        thread_id: Thread the run belongs to.
        run_id: Run whose messages are listed.
        limit: Maximum number of events returned.
        before_seq: Only events with a smaller ``seq`` are considered.
        after_seq: Only events with a larger ``seq`` are considered.
        user_id: Passed to ``resolve_user_id``; when the resolved value is
            not ``None`` only rows with that ``user_id`` are returned.

    Returns:
        A list of event dicts.
    """
    owner = resolve_user_id(user_id, method_name="DbRunEventStore.list_messages_by_run")
    filters = [
        RunEventRow.thread_id == thread_id,
        RunEventRow.run_id == run_id,
        RunEventRow.category == "message",
    ]
    if owner is not None:
        filters.append(RunEventRow.user_id == owner)
    if before_seq is not None:
        filters.append(RunEventRow.seq < before_seq)
    if after_seq is not None:
        filters.append(RunEventRow.seq > after_seq)

    # Forward pagination reads ascending from the cursor; otherwise the
    # newest rows are read first and flipped below.
    forward = after_seq is not None
    ordering = RunEventRow.seq.asc() if forward else RunEventRow.seq.desc()
    stmt = select(RunEventRow).where(*filters).order_by(ordering).limit(limit)

    async with self._sf() as session:
        rows = list((await session.execute(stmt)).scalars())
        if not forward:
            rows.reverse()
        return [self._row_to_dict(row) for row in rows]
async def count_messages(
    self,
    thread_id,
    *,
    user_id: str | None | _AutoSentinel = AUTO,
):
    """Count the ``message``-category events of a thread.

    Args:
        thread_id: Thread whose messages are counted.
        user_id: Passed to ``resolve_user_id``; when the resolved value is
            not ``None`` only rows with that ``user_id`` are counted.

    Returns:
        The number of matching rows; ``0`` when the query yields no value.
    """
    owner = resolve_user_id(user_id, method_name="DbRunEventStore.count_messages")
    conditions = [RunEventRow.thread_id == thread_id, RunEventRow.category == "message"]
    if owner is not None:
        conditions.append(RunEventRow.user_id == owner)
    stmt = select(func.count()).select_from(RunEventRow).where(*conditions)

    async with self._sf() as session:
        total = await session.scalar(stmt)
    return total or 0
async def delete_by_thread(
    self,
    thread_id,
    *,
    user_id: str | None | _AutoSentinel = AUTO,
):
    """Delete every event of a thread and report how many rows were removed.

    The count is taken from the DELETE statement itself rather than from
    a separate COUNT query, so it reflects the rows actually removed even
    when other writers touch the thread concurrently, and only one
    statement is issued.

    Args:
        thread_id: Thread whose events are deleted.
        user_id: Passed to ``resolve_user_id``; when the resolved value is
            not ``None`` only rows with that ``user_id`` are deleted.

    Returns:
        The number of deleted rows (never negative).
    """
    resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_thread")
    conditions = [RunEventRow.thread_id == thread_id]
    if resolved_user_id is not None:
        conditions.append(RunEventRow.user_id == resolved_user_id)

    async with self._sf() as session:
        result = await session.execute(delete(RunEventRow).where(*conditions))
        # Read the count before committing; a driver that cannot report it
        # yields -1, which is clamped to keep the non-negative contract.
        deleted = max(result.rowcount or 0, 0)
        await session.commit()
    return deleted
async def delete_by_run(
    self,
    thread_id,
    run_id,
    *,
    user_id: str | None | _AutoSentinel = AUTO,
):
    """Delete every event of one run and report how many rows were removed.

    The count is taken from the DELETE statement itself rather than from
    a separate COUNT query, so it reflects the rows actually removed even
    when other writers touch the run concurrently, and only one statement
    is issued.

    Args:
        thread_id: Thread the run belongs to.
        run_id: Run whose events are deleted.
        user_id: Passed to ``resolve_user_id``; when the resolved value is
            not ``None`` only rows with that ``user_id`` are deleted.

    Returns:
        The number of deleted rows (never negative).
    """
    resolved_user_id = resolve_user_id(user_id, method_name="DbRunEventStore.delete_by_run")
    conditions = [RunEventRow.thread_id == thread_id, RunEventRow.run_id == run_id]
    if resolved_user_id is not None:
        conditions.append(RunEventRow.user_id == resolved_user_id)

    async with self._sf() as session:
        result = await session.execute(delete(RunEventRow).where(*conditions))
        # Read the count before committing; a driver that cannot report it
        # yields -1, which is clamped to keep the non-negative contract.
        deleted = max(result.rowcount or 0, 0)
        await session.commit()
    return deleted