mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
refactor(persistence): remove UTFJSON, use engine-level json_serializer + datetime.now()
- Replace custom UTFJSON type with standard sqlalchemy.JSON in all ORM models. Add json_serializer=json.dumps(ensure_ascii=False) to all create_async_engine calls so non-ASCII text (Chinese etc.) is stored as-is in both SQLite and Postgres. - Change ORM datetime defaults from datetime.now(UTC) to datetime.now(), remove UTC imports. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -12,14 +12,45 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
_json_serializer = lambda obj: json.dumps(obj, ensure_ascii=False)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_engine: AsyncEngine | None = None
|
_engine: AsyncEngine | None = None
|
||||||
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
_session_factory: async_sessionmaker[AsyncSession] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def _auto_create_postgres_db(url: str) -> None:
|
||||||
|
"""Connect to the ``postgres`` maintenance DB and CREATE DATABASE.
|
||||||
|
|
||||||
|
The target database name is extracted from *url*. The connection is
|
||||||
|
made to the default ``postgres`` database on the same server using
|
||||||
|
``AUTOCOMMIT`` isolation (CREATE DATABASE cannot run inside a
|
||||||
|
transaction).
|
||||||
|
"""
|
||||||
|
from sqlalchemy import text
|
||||||
|
from sqlalchemy.engine.url import make_url
|
||||||
|
|
||||||
|
parsed = make_url(url)
|
||||||
|
db_name = parsed.database
|
||||||
|
if not db_name:
|
||||||
|
raise ValueError("Cannot auto-create database: no database name in URL")
|
||||||
|
|
||||||
|
# Connect to the default 'postgres' database to issue CREATE DATABASE
|
||||||
|
maint_url = parsed.set(database="postgres")
|
||||||
|
maint_engine = create_async_engine(maint_url, isolation_level="AUTOCOMMIT")
|
||||||
|
try:
|
||||||
|
async with maint_engine.connect() as conn:
|
||||||
|
await conn.execute(text(f'CREATE DATABASE "{db_name}"'))
|
||||||
|
logger.info("Auto-created PostgreSQL database: %s", db_name)
|
||||||
|
finally:
|
||||||
|
await maint_engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
async def init_engine(
|
async def init_engine(
|
||||||
backend: str,
|
backend: str,
|
||||||
*,
|
*,
|
||||||
@@ -53,13 +84,14 @@ async def init_engine(
|
|||||||
import os
|
import os
|
||||||
|
|
||||||
os.makedirs(sqlite_dir or ".", exist_ok=True)
|
os.makedirs(sqlite_dir or ".", exist_ok=True)
|
||||||
_engine = create_async_engine(url, echo=echo)
|
_engine = create_async_engine(url, echo=echo, json_serializer=_json_serializer)
|
||||||
elif backend == "postgres":
|
elif backend == "postgres":
|
||||||
_engine = create_async_engine(
|
_engine = create_async_engine(
|
||||||
url,
|
url,
|
||||||
echo=echo,
|
echo=echo,
|
||||||
pool_size=pool_size,
|
pool_size=pool_size,
|
||||||
pool_pre_ping=True,
|
pool_pre_ping=True,
|
||||||
|
json_serializer=_json_serializer,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unknown persistence backend: {backend!r}")
|
raise ValueError(f"Unknown persistence backend: {backend!r}")
|
||||||
@@ -76,8 +108,21 @@ async def init_engine(
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
async with _engine.begin() as conn:
|
async with _engine.begin() as conn:
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
except Exception as exc:
|
||||||
|
if backend == "postgres" and "does not exist" in str(exc):
|
||||||
|
# Database not yet created — attempt to auto-create it, then retry.
|
||||||
|
await _auto_create_postgres_db(url)
|
||||||
|
# Rebuild engine against the now-existing database
|
||||||
|
await _engine.dispose()
|
||||||
|
_engine = create_async_engine(url, echo=echo, pool_size=pool_size, pool_pre_ping=True, json_serializer=_json_serializer)
|
||||||
|
_session_factory = async_sessionmaker(_engine, expire_on_commit=False)
|
||||||
|
async with _engine.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
else:
|
||||||
|
raise
|
||||||
|
|
||||||
logger.info("Persistence engine initialized: backend=%s", backend)
|
logger.info("Persistence engine initialized: backend=%s", backend)
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import String, Text
|
from sqlalchemy import DateTime, String, Text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from deerflow.persistence.base import Base
|
from deerflow.persistence.base import Base
|
||||||
@@ -27,4 +27,4 @@ class FeedbackRow(Base):
|
|||||||
comment: Mapped[str | None] = mapped_column(Text)
|
comment: Mapped[str | None] = mapped_column(Text)
|
||||||
# Optional text feedback from the user
|
# Optional text feedback from the user
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import JSON, Index, String, Text
|
from sqlalchemy import JSON, DateTime, Index, String, Text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from deerflow.persistence.base import Base
|
from deerflow.persistence.base import Base
|
||||||
@@ -43,7 +43,7 @@ class RunRow(Base):
|
|||||||
# Follow-up association
|
# Follow-up association
|
||||||
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
|
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64))
|
||||||
|
|
||||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
||||||
updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now())
|
||||||
|
|
||||||
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
|
__table_args__ = (Index("ix_runs_thread_status", "thread_id", "status"),)
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import JSON, Index, String, Text
|
from sqlalchemy import JSON, DateTime, Index, String, Text
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from deerflow.persistence.base import Base
|
from deerflow.persistence.base import Base
|
||||||
@@ -22,7 +22,7 @@ class RunEventRow(Base):
|
|||||||
content: Mapped[str] = mapped_column(Text, default="")
|
content: Mapped[str] = mapped_column(Text, default="")
|
||||||
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
|
event_metadata: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||||
seq: Mapped[int] = mapped_column(nullable=False)
|
seq: Mapped[int] = mapped_column(nullable=False)
|
||||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
|
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import UTC, datetime
|
from datetime import datetime
|
||||||
|
|
||||||
from sqlalchemy import JSON, String
|
from sqlalchemy import JSON, DateTime, String
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
from deerflow.persistence.base import Base
|
from deerflow.persistence.base import Base
|
||||||
@@ -19,5 +19,5 @@ class ThreadMetaRow(Base):
|
|||||||
display_name: Mapped[str | None] = mapped_column(String(256))
|
display_name: Mapped[str | None] = mapped_column(String(256))
|
||||||
status: Mapped[str] = mapped_column(String(20), default="idle")
|
status: Mapped[str] = mapped_column(String(20), default="idle")
|
||||||
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
metadata_json: Mapped[dict] = mapped_column(JSON, default=dict)
|
||||||
created_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC))
|
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now())
|
||||||
updated_at: Mapped[datetime] = mapped_column(default=lambda: datetime.now(UTC), onupdate=lambda: datetime.now(UTC))
|
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now())
|
||||||
|
|||||||
Reference in New Issue
Block a user