mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-23 00:16:48 +00:00
fix(events): serialize structured db event content (#2762)
This commit is contained in:
@@ -9,6 +9,7 @@ from __future__ import annotations
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from datetime import UTC, datetime
|
from datetime import UTC, datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import delete, func, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||||
@@ -33,20 +34,21 @@ class DbRunEventStore(RunEventStore):
|
|||||||
if isinstance(val, datetime):
|
if isinstance(val, datetime):
|
||||||
d["created_at"] = val.isoformat()
|
d["created_at"] = val.isoformat()
|
||||||
d.pop("id", None)
|
d.pop("id", None)
|
||||||
# Restore dict content that was JSON-serialized on write
|
# Restore structured content that was JSON-serialized on write.
|
||||||
raw = d.get("content", "")
|
raw = d.get("content", "")
|
||||||
if isinstance(raw, str) and d.get("metadata", {}).get("content_is_dict"):
|
metadata = d.get("metadata", {})
|
||||||
|
if isinstance(raw, str) and (metadata.get("content_is_json") or metadata.get("content_is_dict")):
|
||||||
try:
|
try:
|
||||||
d["content"] = json.loads(raw)
|
d["content"] = json.loads(raw)
|
||||||
except (json.JSONDecodeError, ValueError):
|
except (json.JSONDecodeError, ValueError):
|
||||||
# Content looked like JSON (content_is_dict flag) but failed to parse;
|
# Content looked like JSON but failed to parse;
|
||||||
# keep the raw string as-is.
|
# keep the raw string as-is.
|
||||||
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
|
logger.debug("Failed to deserialize content as JSON for event seq=%s", d.get("seq"))
|
||||||
return d
|
return d
|
||||||
|
|
||||||
def _truncate_trace(self, category: str, content: str | dict, metadata: dict | None) -> tuple[str | dict, dict]:
|
def _truncate_trace(self, category: str, content: Any, metadata: dict | None) -> tuple[Any, dict]:
|
||||||
if category == "trace":
|
if category == "trace":
|
||||||
text = json.dumps(content, default=str, ensure_ascii=False) if isinstance(content, dict) else content
|
text = content if isinstance(content, str) else json.dumps(content, default=str, ensure_ascii=False)
|
||||||
encoded = text.encode("utf-8")
|
encoded = text.encode("utf-8")
|
||||||
if len(encoded) > self._max_trace_content:
|
if len(encoded) > self._max_trace_content:
|
||||||
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
|
# Truncate by bytes, then decode back (may cut a multi-byte char, so use errors="ignore")
|
||||||
@@ -54,6 +56,18 @@ class DbRunEventStore(RunEventStore):
|
|||||||
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
|
metadata = {**(metadata or {}), "content_truncated": True, "original_byte_length": len(encoded)}
|
||||||
return content, metadata or {}
|
return content, metadata or {}
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _content_to_db(content: Any, metadata: dict | None) -> tuple[str, dict]:
|
||||||
|
metadata = metadata or {}
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content, metadata
|
||||||
|
|
||||||
|
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
||||||
|
metadata = {**metadata, "content_is_json": True}
|
||||||
|
if isinstance(content, dict):
|
||||||
|
metadata["content_is_dict"] = True
|
||||||
|
return db_content, metadata
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _user_id_from_context() -> str | None:
|
def _user_id_from_context() -> str | None:
|
||||||
"""Soft read of user_id from contextvar for write paths.
|
"""Soft read of user_id from contextvar for write paths.
|
||||||
@@ -82,11 +96,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
the initial ``human_message`` event (once per run).
|
the initial ``human_message`` event (once per run).
|
||||||
"""
|
"""
|
||||||
content, metadata = self._truncate_trace(category, content, metadata)
|
content, metadata = self._truncate_trace(category, content, metadata)
|
||||||
if isinstance(content, dict):
|
db_content, metadata = self._content_to_db(content, metadata)
|
||||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
|
||||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
|
||||||
else:
|
|
||||||
db_content = content
|
|
||||||
user_id = self._user_id_from_context()
|
user_id = self._user_id_from_context()
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
async with session.begin():
|
async with session.begin():
|
||||||
@@ -128,11 +138,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
category = e.get("category", "trace")
|
category = e.get("category", "trace")
|
||||||
metadata = e.get("metadata")
|
metadata = e.get("metadata")
|
||||||
content, metadata = self._truncate_trace(category, content, metadata)
|
content, metadata = self._truncate_trace(category, content, metadata)
|
||||||
if isinstance(content, dict):
|
db_content, metadata = self._content_to_db(content, metadata)
|
||||||
db_content = json.dumps(content, default=str, ensure_ascii=False)
|
|
||||||
metadata = {**(metadata or {}), "content_is_dict": True}
|
|
||||||
else:
|
|
||||||
db_content = content
|
|
||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=e["thread_id"],
|
thread_id=e["thread_id"],
|
||||||
run_id=e["run_id"],
|
run_id=e["run_id"],
|
||||||
|
|||||||
@@ -310,6 +310,28 @@ class TestDbRunEventStore:
|
|||||||
|
|
||||||
await close_engine()
|
await close_engine()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_structured_content_round_trips(self, tmp_path):
|
||||||
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||||
|
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||||
|
|
||||||
|
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||||
|
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||||
|
s = DbRunEventStore(get_session_factory())
|
||||||
|
|
||||||
|
content = [{"type": "text", "text": "hello"}, {"type": "image_url", "image_url": {"url": "https://example.test/a.png"}}]
|
||||||
|
record = await s.put(thread_id="t1", run_id="r1", event_type="ai_message", category="message", content=content)
|
||||||
|
|
||||||
|
assert record["content"] == content
|
||||||
|
assert record["metadata"]["content_is_json"] is True
|
||||||
|
assert "content_is_dict" not in record["metadata"]
|
||||||
|
|
||||||
|
messages = await s.list_messages("t1")
|
||||||
|
assert messages[0]["content"] == content
|
||||||
|
assert messages[0]["metadata"]["content_is_json"] is True
|
||||||
|
|
||||||
|
await close_engine()
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
async def test_pagination(self, tmp_path):
|
async def test_pagination(self, tmp_path):
|
||||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||||
@@ -373,6 +395,55 @@ class TestDbRunEventStore:
|
|||||||
assert seqs == list(range(1, 51))
|
assert seqs == list(range(1, 51))
|
||||||
await close_engine()
|
await close_engine()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_put_batch_accepts_structured_content(self, tmp_path):
|
||||||
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||||
|
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||||
|
|
||||||
|
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||||
|
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||||
|
s = DbRunEventStore(get_session_factory())
|
||||||
|
|
||||||
|
content = [{"messages": [{"type": "ai", "content": ""}]}]
|
||||||
|
results = await s.put_batch(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
"thread_id": "t1",
|
||||||
|
"run_id": "r1",
|
||||||
|
"event_type": "run.end",
|
||||||
|
"category": "outputs",
|
||||||
|
"content": content,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
assert results[0]["content"] == content
|
||||||
|
assert results[0]["metadata"]["content_is_json"] is True
|
||||||
|
|
||||||
|
events = await s.list_events("t1", "r1")
|
||||||
|
assert events[0]["content"] == content
|
||||||
|
assert events[0]["metadata"]["content_is_json"] is True
|
||||||
|
|
||||||
|
await close_engine()
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_dict_content_keeps_legacy_metadata_flag(self, tmp_path):
|
||||||
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||||
|
from deerflow.runtime.events.store.db import DbRunEventStore
|
||||||
|
|
||||||
|
url = f"sqlite+aiosqlite:///{tmp_path / 'test.db'}"
|
||||||
|
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||||
|
s = DbRunEventStore(get_session_factory())
|
||||||
|
|
||||||
|
content = {"status": "success"}
|
||||||
|
record = await s.put(thread_id="t1", run_id="r1", event_type="run.end", category="outputs", content=content)
|
||||||
|
|
||||||
|
assert record["content"] == content
|
||||||
|
assert record["metadata"]["content_is_json"] is True
|
||||||
|
assert record["metadata"]["content_is_dict"] is True
|
||||||
|
|
||||||
|
await close_engine()
|
||||||
|
|
||||||
|
|
||||||
# -- Factory tests --
|
# -- Factory tests --
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user