Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | d7a2fff7e0 | | |
| | eabd78ce4e | | |
| | 45060a9ffc | | |
| | 722c690f4f | | |
| | 533d3fbfee | | |
| | d6b3a277a5 | | |
| | def2a3ad79 | | |
| | 3c0b42d836 | | |
| | 34ec205e1d | | |
| | 11a9041b65 | | |
| | d3066a1746 | | |
| | 485f8a2bf2 | | |
@@ -0,0 +1,401 @@
# Storage Package Design

## Background

DeerFlow currently has several persistence responsibilities spread across app, gateway, runtime, and legacy persistence modules. This makes the persistence boundary difficult to reason about and creates several migration risks:

- Routers and runtime services can accidentally depend on concrete persistence implementations instead of stable contracts.
- User/auth, run metadata, thread metadata, feedback, run events, and checkpointer setup are initialized through different paths.
- Some persistence behavior is duplicated between memory, SQLite, and PostgreSQL-oriented code paths.
- Incremental migration is hard because app-level code and storage-level code are coupled.
- Adding or validating another SQL backend requires touching app/runtime code instead of a storage-owned package.

The storage package is introduced to make application data persistence a package-level capability with explicit contracts, a clear boundary, and SQL backend compatibility.

## Goals

- Provide a standalone `packages/storage` package for durable application data.
- Support SQLite, PostgreSQL, and MySQL through a shared persistence construction flow.
- Keep LangGraph checkpointer initialization compatible with the same database backend.
- Expose repository contracts as the only package-level data access boundary.
- Let the app layer depend on app-owned adapters under `app.infra.storage`, not on storage DB implementation classes.
- Allow the app/gateway migration to happen in small steps without forcing a large rewrite.

## Non-Goals

- This design does not remove legacy persistence in the first PR.
- This design does not move routers directly onto storage package models.
- This design does not make app routers own SQLAlchemy sessions.
- Cron persistence is intentionally out of scope for the storage package foundation.
- The memory backend is not part of the durable storage package. Memory compatibility, if the app runtime still needs it, belongs outside `packages/storage`.

## Storage Design Principles

### Package-Owned Durable Storage

`packages/storage` owns durable application data persistence. It defines:

- configuration shape for storage-backed persistence
- SQLAlchemy models
- repository contracts and DTOs
- SQL repository implementations
- persistence factory functions
- compatibility helpers for config-driven initialization

The package should be usable without importing `app.gateway`, routers, auth providers, or runtime-specific gateway objects.

### SQL Backend Compatibility

The package supports three SQL backends:

- SQLite for local/single-node deployments
- PostgreSQL for production multi-node deployments
- MySQL for deployments that standardize on MySQL

Backend-specific differences are handled inside the storage package:

- SQLAlchemy async engine URL construction
- LangGraph checkpointer connection-string compatibility
- JSON metadata filtering across SQLite/PostgreSQL/MySQL
- SQL dialect behavior around locking, aggregation, and JSON type semantics

### Unified Persistence Bundle

Storage initialization returns an `AppPersistence` bundle:

```python
@dataclass(slots=True)
class AppPersistence:
    checkpointer: Checkpointer
    engine: AsyncEngine
    session_factory: async_sessionmaker[AsyncSession]
    setup: Callable[[], Awaitable[None]]
    aclose: Callable[[], Awaitable[None]]
```

The app runtime can initialize persistence once, call `setup()`, and then inject:

- `checkpointer`
- `session_factory`
- repository adapters

This keeps checkpointer and application data aligned to the same backend without requiring routers to understand database configuration.

## Package Layout

```text
backend/packages/storage/
  store/
    config/
      storage_config.py
      app_config.py
    persistence/
      factory.py
      types.py
      base_model.py
      json_compat.py
      drivers/
        sqlite.py
        postgres.py
        mysql.py
    repositories/
      contracts/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      models/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      db/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      factory.py
```

## Persistence Construction

The primary storage entrypoint is:

```python
from store.persistence import create_persistence_from_storage_config

persistence = await create_persistence_from_storage_config(storage_config)
await persistence.setup()
```

For app-level compatibility with existing database config shape:

```python
from store.persistence import create_persistence_from_database_config

persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
```

Expected app startup flow:

```python
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()

app.state.persistence = persistence
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```

Expected app shutdown flow:

```python
await app.state.persistence.aclose()
```

## Repository Contract Design

Repository contracts are the storage package's public data access boundary. They live under `store.repositories.contracts` and are re-exported from `store.repositories`.

The key contract groups are:

- `UserRepositoryProtocol`
- `RunRepositoryProtocol`
- `ThreadMetaRepositoryProtocol`
- `FeedbackRepositoryProtocol`
- `RunEventRepositoryProtocol`

Each contract owns:

- input DTOs, such as `UserCreate`, `RunCreate`, `ThreadMetaCreate`
- output DTOs, such as `User`, `Run`, `ThreadMeta`
- repository protocol methods
- domain-specific exceptions when needed, such as `InvalidMetadataFilterError`

Repository construction is session-based:

```python
from store.repositories import build_run_repository

async with persistence.session_factory() as session:
    repo = build_run_repository(session)
    run = await repo.get_run(run_id)
```

This keeps transaction ownership explicit. The storage package does not hide commits or session lifecycle inside global singletons.

## App/Infra Calling Contract

The app layer should not call `store.repositories.db.*` directly. The intended app boundary is `app.infra.storage`.

`app.infra.storage` is responsible for:

- receiving `session_factory` from FastAPI runtime initialization
- owning session lifecycle for app-facing repository methods
- translating storage DTOs to app/gateway DTOs only when needed
- preserving the existing app-facing names during migration
- depending on storage repository protocols, not concrete DB classes

Expected adapter pattern:

```python
class StorageRunRepository(RunRepositoryProtocol):
    def __init__(self, session_factory):
        self._session_factory = session_factory

    async def get_run(self, run_id: str):
        async with self._session_factory() as session:
            repo = build_run_repository(session)
            return await repo.get_run(run_id)
```

For gateway compatibility, app state can keep existing names while the implementation changes:

```python
app.state.run_store = StorageRunStore(run_repository)
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
app.state.run_event_store = StorageRunEventStore(run_event_repository)
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```

The app-facing objects may expose legacy method names during migration, but their internal data access should go through storage contracts.

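To make the session-ownership rule concrete, here is a self-contained sketch of the adapter pattern with an in-memory fake in place of SQLAlchemy. `FakeSession`, `SessionRunRepository`, and `make_session_factory` are hypothetical helpers, and the `Run` DTO fields are assumed for illustration; only the names `RunRepositoryProtocol` and `StorageRunRepository` come from the design above.

```python
import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Optional, Protocol


@dataclass(frozen=True)
class Run:
    """Illustrative output DTO; the real contract may carry more fields."""

    run_id: str
    status: str


class RunRepositoryProtocol(Protocol):
    async def get_run(self, run_id: str) -> Optional[Run]: ...


class FakeSession:
    """Hypothetical stand-in for an AsyncSession, backed by a dict."""

    def __init__(self, rows: dict):
        self.rows = rows
        self.closed = False


class SessionRunRepository:
    """Session-scoped repository, analogous to what build_run_repository returns."""

    def __init__(self, session: FakeSession):
        self._session = session

    async def get_run(self, run_id: str) -> Optional[Run]:
        status = self._session.rows.get(run_id)
        return Run(run_id, status) if status is not None else None


class StorageRunRepository:
    """App-facing adapter: opens one session per call and closes it afterwards."""

    def __init__(self, session_factory):
        self._session_factory = session_factory

    async def get_run(self, run_id: str) -> Optional[Run]:
        async with self._session_factory() as session:
            return await SessionRunRepository(session).get_run(run_id)


def make_session_factory(rows: dict, opened: list):
    @asynccontextmanager
    async def factory():
        session = FakeSession(rows)
        opened.append(session)
        try:
            yield session
        finally:
            session.closed = True

    return factory
```

A router holding a `StorageRunRepository` never sees a session: each `get_run` call opens one, uses it, and closes it, which is the property the bullet list above asks the infra layer to own.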
## Boundary Rules

### Allowed Calls

Storage package callers may use:

```python
from store.persistence import create_persistence_from_database_config
from store.persistence import create_persistence_from_storage_config
from store.repositories import build_run_repository
from store.repositories import build_user_repository
from store.repositories import build_thread_meta_repository
from store.repositories import build_feedback_repository
from store.repositories import build_run_event_repository
from store.repositories import RunRepositoryProtocol
from store.repositories import UserRepositoryProtocol
```

App layer callers should use:

```python
from app.infra.storage import StorageRunRepository
from app.infra.storage import StorageUserDataRepository
from app.infra.storage import StorageThreadMetaRepository
from app.infra.storage import StorageFeedbackRepository
from app.infra.storage import StorageRunEventRepository
```

### Prohibited Calls

App/gateway/router/auth code must not import:

```python
from store.repositories.db import DbRunRepository
from store.repositories.models import Run
from store.persistence.base_model import MappedBase
```

Routers must not:

- create SQLAlchemy engines
- create SQLAlchemy sessions directly
- call storage DB repository classes directly
- commit/rollback storage transactions directly unless explicitly scoped by an infra adapter
- depend on storage SQLAlchemy model classes

Storage package code must not import:

```python
import app.gateway
import app.infra
import deerflow.runtime
```

The dependency direction is:

```text
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
```

The reverse direction is forbidden.

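Rules like these tend to erode unless something checks them. One possible way to enforce the "storage package code must not import" rule is a small AST scan run in unit tests. The sketch below is an assumption about how such a check could look, not part of this change; `find_forbidden_imports` is a hypothetical helper name.

```python
import ast

# Prefixes the storage package must not import, taken from the rule above.
FORBIDDEN_PREFIXES = ("app.gateway", "app.infra", "deerflow.runtime")


def find_forbidden_imports(source: str, forbidden=FORBIDDEN_PREFIXES) -> list:
    """Return imported module names in `source` that fall under a forbidden prefix."""
    found = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            names = [node.module]
        else:
            continue
        for name in names:
            # Match the prefix itself or any submodule, but not look-alike names.
            if any(name == p or name.startswith(p + ".") for p in forbidden):
                found.append(name)
    return found
```

A test could read every `.py` file under the package and assert that the result is empty. Note the limits of this sketch: it does not catch `from app import gateway` or imports performed dynamically.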
## Checkpointer Compatibility

The storage persistence bundle initializes the LangGraph checkpointer alongside application data persistence.

Backend-specific notes:

- SQLite uses `langgraph-checkpoint-sqlite`.
- PostgreSQL uses `langgraph-checkpoint-postgres` and requires a string `postgresql://...` connection URL.
- MySQL uses `langgraph-checkpoint-mysql` and requires a string MySQL connection URL.

SQLAlchemy may use async driver URLs such as `postgresql+asyncpg://...` or `mysql+aiomysql://...`, but LangGraph checkpointer constructors expect plain string connection URLs. This conversion belongs inside the storage driver implementation.

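The URL conversion described above amounts to dropping the `+driver` suffix from the scheme. A minimal sketch, assuming plain string handling is sufficient; `to_checkpointer_url` is a hypothetical name, and the actual drivers may use SQLAlchemy's URL utilities instead:

```python
def to_checkpointer_url(sqlalchemy_url: str) -> str:
    """Drop the '+driver' suffix from the URL scheme, leaving the rest untouched."""
    scheme, sep, rest = sqlalchemy_url.partition("://")
    if not sep:
        raise ValueError("expected a URL of the form scheme://...")
    # 'postgresql+asyncpg' -> 'postgresql'; a scheme without '+' is kept as-is.
    dialect = scheme.split("+", 1)[0]
    return f"{dialect}://{rest}"
```

For example, `to_checkpointer_url("postgresql+asyncpg://u:p@db:5432/app")` yields `postgresql://u:p@db:5432/app`. The sketch covers only the two server backends and does not address how the SQLite checkpointer receives its database location.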
## JSON Metadata Filtering

Thread metadata search supports dialect-aware JSON filtering through `store.persistence.json_compat`.

The matcher supports:

- `None`
- `bool`
- `int`
- `float`
- `str`

It rejects:

- unsafe keys
- nested JSON path expressions
- dict/list values
- integers outside signed 64-bit range

This prevents SQL/JSON path injection, avoids compiled-cache type drift, and preserves type semantics such as `True != 1` and explicit JSON `null` not matching a missing key.

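The accept/reject rules above can be read as a validation step that runs before any SQL is built. The following sketch illustrates that step only; it is not the `json_compat` implementation. The key pattern is an assumed definition of "safe", `classify_filter` is a hypothetical helper, and `InvalidMetadataFilterError` is redefined locally so the example is self-contained.

```python
import re

INT64_MIN, INT64_MAX = -(2**63), 2**63 - 1
# Assumed definition of a safe key: identifier-like, so no dots, brackets, or quotes.
SAFE_KEY = re.compile(r"[A-Za-z_][A-Za-z0-9_-]*")


class InvalidMetadataFilterError(ValueError):
    """Raised when a metadata filter cannot be compiled safely."""


def classify_filter(key, value):
    """Validate one key/value pair and return (key, json_type_tag)."""
    if not isinstance(key, str) or not SAFE_KEY.fullmatch(key):
        raise InvalidMetadataFilterError(f"unsafe metadata key: {key!r}")
    if value is None:
        return key, "null"
    # bool is checked before int because bool is a subclass of int in Python.
    if isinstance(value, bool):
        return key, "boolean"
    if isinstance(value, int):
        if not INT64_MIN <= value <= INT64_MAX:
            raise InvalidMetadataFilterError(f"integer out of signed 64-bit range for {key!r}")
        return key, "integer"
    if isinstance(value, float):
        return key, "number"
    if isinstance(value, str):
        return key, "string"
    raise InvalidMetadataFilterError(f"unsupported value type: {type(value).__name__}")
```

Returning an explicit type tag is what lets a dialect-specific compiler keep `True` and `1` apart: the tag, rather than Python equality, decides which JSON comparison is emitted.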
## Step-by-Step Implementation Plan

### Step 1: Introduce Storage Package Foundation

- Add `backend/packages/storage`.
- Add storage config models.
- Add `AppPersistence`.
- Add SQLite/PostgreSQL/MySQL persistence drivers.
- Add repository contracts, models, DB implementations, and factory helpers.
- Add package dependency wiring.
- Exclude cron persistence.

### Step 2: Harden Storage Backend Compatibility

- Validate SQLite setup and repository behavior.
- Validate PostgreSQL and MySQL with local E2E tests.
- Fix checkpointer connection-string compatibility.
- Fix PostgreSQL locking and aggregation differences.
- Add dialect-aware JSON metadata filtering.

### Step 3: Add App Infra Adapters

- Add `backend/app/infra/storage`.
- Implement app-facing repositories that own session lifecycle.
- Keep storage contracts as the only data access boundary.
- Add legacy compatibility adapters for existing app/gateway method shapes.
- Keep app/gateway imports out of `packages/storage`.

### Step 4: Switch FastAPI Runtime Injection

- Initialize storage persistence in FastAPI startup/lifespan.
- Attach `persistence`, `checkpointer`, and `session_factory` to `app.state`.
- Preserve existing external state names:
  - `run_store`
  - `feedback_repo`
  - `thread_store`
  - `run_event_store`
  - `checkpointer`
  - `session_factory`
- Start with user/auth provider construction, then migrate run/thread/feedback/run_event.

### Step 5: Router and Auth Compatibility

- Ensure routers consume app-facing adapters, not storage DB classes.
- Ensure auth providers depend on user repository contracts.
- Keep router response shapes unchanged.
- Add focused auth/admin/router regression tests.

### Step 6: Clean Up Legacy Persistence

- Compare old persistence usage after app/gateway migration.
- Remove unused old repository implementations only after all call sites move.
- Keep compatibility shims only where needed for a transition window.
- Delete memory backend paths from storage-owned durable persistence.

## Testing Strategy

Unit tests should cover:

- config parsing
- persistence setup
- table creation
- repository CRUD/query behavior
- typed JSON metadata filtering
- dialect SQL compilation
- cron exclusion

E2E tests should cover:

- SQLite persistence setup
- PostgreSQL temporary database setup
- MySQL temporary database setup
- repository contract behavior across all supported SQL backends
- JSON/Unicode round trip
- rollback behavior
- persistence close/cleanup

E2E tests may remain local-only if CI does not provide PostgreSQL/MySQL services.

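One way to keep the PostgreSQL/MySQL E2E tests local-only without deleting them from the suite is to skip them when no service is configured. This is a sketch under assumptions: the environment variable names and `e2e_url` are hypothetical, and the test body is a placeholder for real contract checks.

```python
import os
import unittest

# Hypothetical variable names; the project may configure E2E targets differently.
E2E_ENV_VARS = {
    "postgresql": "STORAGE_E2E_POSTGRES_URL",
    "mysql": "STORAGE_E2E_MYSQL_URL",
}


def e2e_url(backend: str, environ=os.environ):
    """Return the configured E2E connection URL for a backend, or None if unset."""
    return environ.get(E2E_ENV_VARS[backend]) or None


class PostgresContractE2E(unittest.TestCase):
    def setUp(self):
        self.url = e2e_url("postgresql")
        if self.url is None:
            self.skipTest("PostgreSQL E2E service not configured")

    def test_url_is_plain_string(self):
        # Placeholder assertion; a real test would exercise the repository contracts.
        self.assertIsInstance(self.url, str)
```

With this shape, CI without database services reports the tests as skipped rather than failed, and a developer enables them locally by exporting the variable.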
@@ -0,0 +1,401 @@
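# Storage Package Design Document

## Background

DeerFlow currently has several kinds of persistence responsibilities scattered across the app, gateway, runtime, and legacy persistence modules. This causes several problems:

- Routers and runtime services can easily depend on concrete persistence implementations rather than stable contracts.
- The initialization paths for user/auth, run metadata, thread metadata, feedback, run events, and checkpointer setup are not unified.
- The memory, SQLite, and PostgreSQL code paths contain some duplicated logic.
- App-layer code and storage-layer code are coupled, which makes incremental migration difficult.
- Adding or validating a new SQL backend requires changes to app/runtime rather than only to the storage package.

The goal of introducing the storage package is to abstract application data persistence into a package-level capability that provides explicit contracts, clear boundaries, and SQL backend compatibility.

## Goals

- Add a standalone `packages/storage` responsible for durable application data.
- Support SQLite, PostgreSQL, and MySQL through a unified persistence construction flow.
- Keep the LangGraph checkpointer compatible with the same database backend.
- Make repository contracts the package's only externally visible data access boundary.
- Have the app layer adapt storage through `app.infra.storage` instead of depending directly on storage DB implementation classes.
- Support later migration of app/gateway in small steps, avoiding a single large refactor.

## Non-Goals

- The first phase does not delete the legacy persistence.
- Routers do not depend directly on storage package models.
- App routers do not manage SQLAlchemy sessions.
- Cron persistence is outside the scope of the storage package foundation migration.
- The memory backend is not part of the durable storage package. If the app runtime still needs memory compatibility, it should live outside `packages/storage`.
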
## Storage Design Philosophy

### The Package Owns Durable Storage

`packages/storage` is responsible for the durable persistence of application data, including:

- storage persistence configuration
- SQLAlchemy models
- repository contracts and DTOs
- SQL repository implementations
- persistence factory functions
- a compatibility initialization entrypoint for the existing config

The package should not import `app.gateway`, routers, auth providers, or gateway objects from the runtime.

### SQL Backend Compatibility

The package supports three SQL backends:

- SQLite: local or single-node deployments
- PostgreSQL: production multi-node deployments
- MySQL: deployments that use MySQL as their standard database

Backend differences are handled inside the storage package:

- SQLAlchemy async engine URL construction
- LangGraph checkpointer connection-string compatibility
- JSON metadata filters for SQLite/PostgreSQL/MySQL
- differences between SQL dialects in locking, aggregation, and JSON type semantics

### Unified Persistence Bundle

Storage initialization returns an `AppPersistence` bundle:

```python
@dataclass(slots=True)
class AppPersistence:
    checkpointer: Checkpointer
    engine: AsyncEngine
    session_factory: async_sessionmaker[AsyncSession]
    setup: Callable[[], Awaitable[None]]
    aclose: Callable[[], Awaitable[None]]
```

The app runtime only needs to initialize persistence once, call `setup()`, and then inject:

- `checkpointer`
- `session_factory`
- repository adapters

This way the checkpointer and application data are aligned to the same backend, and routers do not need to understand database configuration.

## Package Structure

```text
backend/packages/storage/
  store/
    config/
      storage_config.py
      app_config.py
    persistence/
      factory.py
      types.py
      base_model.py
      json_compat.py
      drivers/
        sqlite.py
        postgres.py
        mysql.py
    repositories/
      contracts/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      models/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      db/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      factory.py
```

## Persistence Construction

The main storage entrypoint:

```python
from store.persistence import create_persistence_from_storage_config

persistence = await create_persistence_from_storage_config(storage_config)
await persistence.setup()
```

For compatibility with the existing app database config, the package also provides:

```python
from store.persistence import create_persistence_from_database_config

persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
```

Expected app startup flow:

```python
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()

app.state.persistence = persistence
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```

Expected app shutdown flow:

```python
await app.state.persistence.aclose()
```

## Repository Contract Design

Repository contracts are the data access boundary that the storage package exposes publicly. They live in `store.repositories.contracts` and are re-exported through `store.repositories`.

The main contracts are:

- `UserRepositoryProtocol`
- `RunRepositoryProtocol`
- `ThreadMetaRepositoryProtocol`
- `FeedbackRepositoryProtocol`
- `RunEventRepositoryProtocol`

Each contract group contains:

- input DTOs, such as `UserCreate`, `RunCreate`, `ThreadMetaCreate`
- output DTOs, such as `User`, `Run`, `ThreadMeta`
- repository protocol methods
- domain exceptions where necessary, such as `InvalidMetadataFilterError`

Repositories are constructed from a session:

```python
from store.repositories import build_run_repository

async with persistence.session_factory() as session:
    repo = build_run_repository(session)
    run = await repo.get_run(run_id)
```

This keeps transaction ownership explicit. The storage package does not implicitly hide commits or session lifecycle behind global singletons.

## App/Infra Calling Contract

The app layer should not call `store.repositories.db.*` directly. The intended app boundary is `app.infra.storage`.

`app.infra.storage` is responsible for:

- receiving `session_factory` from FastAPI runtime initialization
- managing session lifecycle for app-facing repository methods
- converting storage DTOs to app/gateway DTOs when necessary
- keeping the existing app-facing names during migration
- depending on storage repository protocols rather than concrete DB classes

Expected adapter pattern:

```python
class StorageRunRepository(RunRepositoryProtocol):
    def __init__(self, session_factory):
        self._session_factory = session_factory

    async def get_run(self, run_id: str):
        async with self._session_factory() as session:
            repo = build_run_repository(session)
            return await repo.get_run(run_id)
```

For gateway compatibility, app state can temporarily keep the existing names and replace only the internal implementation:

```python
app.state.run_store = StorageRunStore(run_repository)
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
app.state.run_event_store = StorageRunEventStore(run_event_repository)
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```

App-facing objects may keep legacy method names during migration, but internal data access must go through storage contracts.

## Boundary Rules

### Allowed Calls

Callers of the storage package may use:

```python
from store.persistence import create_persistence_from_database_config
from store.persistence import create_persistence_from_storage_config
from store.repositories import build_run_repository
from store.repositories import build_user_repository
from store.repositories import build_thread_meta_repository
from store.repositories import build_feedback_repository
from store.repositories import build_run_event_repository
from store.repositories import RunRepositoryProtocol
from store.repositories import UserRepositoryProtocol
```

The app layer should use:

```python
from app.infra.storage import StorageRunRepository
from app.infra.storage import StorageUserDataRepository
from app.infra.storage import StorageThreadMetaRepository
from app.infra.storage import StorageFeedbackRepository
from app.infra.storage import StorageRunEventRepository
```

### Prohibited Calls

App/gateway/router/auth code should not import:

```python
from store.repositories.db import DbRunRepository
from store.repositories.models import Run
from store.persistence.base_model import MappedBase
```

Routers must not:

- create SQLAlchemy engines
- create SQLAlchemy sessions directly
- call storage DB repository classes directly
- commit/rollback storage transactions directly, unless this falls within a scope explicitly managed by an infra adapter
- depend on storage SQLAlchemy model classes

The storage package must not import:

```python
import app.gateway
import app.infra
import deerflow.runtime
```

The dependency direction must be:

```text
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
```

Reverse dependencies are forbidden.

## Checkpointer Compatibility

The storage persistence bundle initializes both the LangGraph checkpointer and application data persistence.

Backend notes:

- SQLite uses `langgraph-checkpoint-sqlite`.
- PostgreSQL uses `langgraph-checkpoint-postgres` and needs a `postgresql://...` connection string in string form.
- MySQL uses `langgraph-checkpoint-mysql` and needs a MySQL connection string in string form.

SQLAlchemy can use async driver URLs such as `postgresql+asyncpg://...` or `mysql+aiomysql://...`, but the LangGraph checkpointer constructors need plain string connection strings. This conversion should be encapsulated inside the storage driver implementation.

## JSON Metadata Filtering

Thread metadata search supports cross-dialect JSON filtering through `store.persistence.json_compat`.

Supported filter value types:

- `None`
- `bool`
- `int`
- `float`
- `str`

Rejected:

- unsafe keys
- nested JSON path expressions
- dict/list values
- integers outside the signed 64-bit range

This avoids SQL/JSON path injection, avoids compiled-cache type drift, and preserves type semantics such as `True != 1` and an explicit JSON `null` not being equal to a missing key.

## Step-by-Step Implementation Plan

### Step 1: Add the Storage Package Foundation

- Add `backend/packages/storage`.
- Add storage config models.
- Add `AppPersistence`.
- Add SQLite/PostgreSQL/MySQL persistence drivers.
- Add repository contracts, models, DB implementations, and factory helpers.
- Wire up the package dependency.
- Exclude cron persistence.

### Step 2: Complete Storage Backend Compatibility

- Validate SQLite setup and repository behavior.
- Validate PostgreSQL and MySQL with local E2E tests.
- Fix checkpointer connection-string compatibility.
- Fix PostgreSQL locking and aggregation differences.
- Add cross-dialect JSON metadata filtering.

### Step 3: Add App Infra Adapters

- Add `backend/app/infra/storage`.
- Implement app-facing repositories that manage session lifecycle.
- Keep storage contracts as the only data access boundary.
- Add compatibility adapters for existing app/gateway method shapes.
- Prevent `packages/storage` from importing app/gateway.

### Step 4: Switch FastAPI Runtime Injection

- Initialize storage persistence in FastAPI startup/lifespan.
- Inject `persistence`, `checkpointer`, and `session_factory` into `app.state`.
- Temporarily keep the existing external state names:
  - `run_store`
  - `feedback_repo`
  - `thread_store`
  - `run_event_store`
  - `checkpointer`
  - `session_factory`
- Switch user/auth provider construction first, then gradually migrate run/thread/feedback/run_event.

### Step 5: Router and Auth Compatibility

- Ensure routers consume app-facing adapters rather than storage DB classes.
- Ensure auth providers depend on user repository contracts.
- Keep router response shapes unchanged.
- Add auth/admin/router regression tests.

### Step 6: Clean Up Legacy Persistence

- Compare legacy persistence usage only after the app/gateway migration is complete.
- Delete unused legacy repository implementations only after all call sites have migrated.
- Keep short-term compatibility shims only where necessary.
- Remove memory backend paths from storage-owned durable persistence.

## Testing Strategy

Unit tests should cover:

- config parsing
- persistence setup
- table creation
- repository CRUD/query behavior
- typed JSON metadata filtering
- dialect SQL compilation
- cron exclusion

E2E tests should cover:

- SQLite persistence setup
- PostgreSQL temporary database setup
- MySQL temporary database setup
- repository contract behavior on all supported SQL backends
- JSON/Unicode round trip
- rollback behavior
- persistence close/cleanup

If CI does not yet provide PostgreSQL/MySQL services, the E2E tests can be kept as local-only validation for now.

@@ -40,6 +40,15 @@ class MemoryUpdateQueue:
         self._timer: threading.Timer | None = None
         self._processing = False
 
+    @staticmethod
+    def _queue_key(
+        thread_id: str,
+        user_id: str | None,
+        agent_name: str | None,
+    ) -> tuple[str, str | None, str | None]:
+        """Return the debounce identity for a memory update target."""
+        return (thread_id, user_id, agent_name)
+
     def add(
         self,
         thread_id: str,
@@ -115,8 +124,9 @@ class MemoryUpdateQueue:
         correction_detected: bool,
         reinforcement_detected: bool,
     ) -> None:
+        queue_key = self._queue_key(thread_id, user_id, agent_name)
         existing_context = next(
-            (context for context in self._queue if context.thread_id == thread_id),
+            (context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key),
             None,
         )
         merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
@@ -130,7 +140,7 @@ class MemoryUpdateQueue:
             reinforcement_detected=merged_reinforcement_detected,
         )
 
-        self._queue = [c for c in self._queue if c.thread_id != thread_id]
+        self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key]
         self._queue.append(context)
 
     def _reset_timer(self) -> None:

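Reviewer note: the change above widens the debounce identity from `thread_id` alone to `(thread_id, user_id, agent_name)`, so pending updates for different users or agents on the same thread no longer overwrite each other. The standalone sketch below illustrates that behavior; `PendingUpdate` and `enqueue` are simplified, hypothetical stand-ins rather than the actual `MemoryUpdateQueue` internals.

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PendingUpdate:
    thread_id: str
    user_id: Optional[str]
    agent_name: Optional[str]
    messages: list = field(default_factory=list)
    correction_detected: bool = False


def queue_key(update: PendingUpdate) -> tuple:
    """Debounce identity: updates coalesce only when all three parts match."""
    return (update.thread_id, update.user_id, update.agent_name)


def enqueue(queue: list, new: PendingUpdate) -> list:
    """Replace any pending update with the same key, OR-merging its flag."""
    key = queue_key(new)
    existing = next((u for u in queue if queue_key(u) == key), None)
    if existing is not None:
        new.correction_detected = new.correction_detected or existing.correction_detected
    return [u for u in queue if queue_key(u) != key] + [new]
```

Enqueuing updates for `alice` and `bob` on the same thread keeps both entries, while a second update for `alice` replaces her earlier one and inherits its `correction_detected` flag.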
@@ -6,6 +6,7 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_
 from deerflow.agents.memory.queue import get_memory_queue
 from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
 from deerflow.config.memory_config import get_memory_config
+from deerflow.runtime.user_context import resolve_runtime_user_id
 
 
 def memory_flush_hook(event: SummarizationEvent) -> None:
@@ -21,11 +22,13 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
 
     correction_detected = detect_correction(filtered_messages)
     reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
+    user_id = resolve_runtime_user_id(event.runtime)
     queue = get_memory_queue()
     queue.add_nowait(
         thread_id=event.thread_id,
         messages=filtered_messages,
         agent_name=event.agent_name,
+        user_id=user_id,
         correction_detected=correction_detected,
         reinforcement_detected=reinforcement_detected,
     )

@@ -11,7 +11,7 @@ import logging
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy import delete, func, select, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
from deerflow.persistence.models.run_event import RunEventRow
|
||||
@@ -86,6 +86,28 @@ class DbRunEventStore(RunEventStore):
         user = get_current_user()
         return str(user.id) if user is not None else None

+    @staticmethod
+    async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None:
+        """Return the current max seq while serializing writers per thread.
+
+        PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate
+        results are not lockable rows. As a release-safe workaround, take a
+        transaction-level advisory lock keyed by thread_id before reading the
+        aggregate. Other dialects keep the existing row-locking statement.
+        """
+        stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)
+        bind = session.get_bind()
+        dialect_name = bind.dialect.name if bind is not None else ""
+
+        if dialect_name == "postgresql":
+            await session.execute(
+                text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"),
+                {"thread_id": thread_id},
+            )
+            return await session.scalar(stmt)
+
+        return await session.scalar(stmt.with_for_update())
+
     async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None):  # noqa: D401
         """Write a single event — low-frequency path only.

@@ -100,10 +122,7 @@ class DbRunEventStore(RunEventStore):
         user_id = self._user_id_from_context()
         async with self._sf() as session:
             async with session.begin():
-                # Use FOR UPDATE to serialize seq assignment within a thread.
-                # NOTE: with_for_update() on aggregates is a no-op on SQLite;
-                # the UNIQUE(thread_id, seq) constraint catches races there.
-                max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
+                max_seq = await self._max_seq_for_thread(session, thread_id)
                 seq = (max_seq or 0) + 1
                 row = RunEventRow(
                     thread_id=thread_id,
@@ -126,10 +145,8 @@ class DbRunEventStore(RunEventStore):
         async with self._sf() as session:
             async with session.begin():
                 # Get max seq for the thread (assume all events in batch belong to same thread).
-                # NOTE: with_for_update() on aggregates is a no-op on SQLite;
-                # the UNIQUE(thread_id, seq) constraint catches races there.
                 thread_id = events[0]["thread_id"]
-                max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
+                max_seq = await self._max_seq_for_thread(session, thread_id)
                 seq = max_seq or 0
                 rows = []
                 for e in events:
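Review note: the removed comments relied on the `UNIQUE(thread_id, seq)` constraint to catch races on SQLite, where `FOR UPDATE` does nothing. The sketch below illustrates that fallback using only the standard library. The simplified `run_events` table and the helper names `next_seq` and `append_event` are assumptions made for illustration; they are not part of `DbRunEventStore`.

```python
import sqlite3


def next_seq(conn: sqlite3.Connection, thread_id: str) -> int:
    """Read max(seq) for the thread and return the next value (1 for an empty thread)."""
    row = conn.execute(
        "SELECT max(seq) FROM run_events WHERE thread_id = ?", (thread_id,)
    ).fetchone()
    return (row[0] or 0) + 1


def append_event(conn: sqlite3.Connection, thread_id: str, seq: int) -> None:
    """Insert one event; the UNIQUE constraint rejects a duplicate (thread_id, seq)."""
    with conn:  # commits on success, rolls back on error
        conn.execute(
            "INSERT INTO run_events (thread_id, seq) VALUES (?, ?)", (thread_id, seq)
        )


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE run_events ("
    "thread_id TEXT NOT NULL, seq INTEGER NOT NULL, UNIQUE (thread_id, seq))"
)

first = next_seq(conn, "t1")
append_event(conn, "t1", first)

# Two writers that both read max(seq) before either inserts compute the same value.
stale_a = next_seq(conn, "t1")
stale_b = next_seq(conn, "t1")
append_event(conn, "t1", stale_a)
try:
    append_event(conn, "t1", stale_b)
    race_rejected = False
except sqlite3.IntegrityError:
    # The second writer loses; a caller would re-read max(seq) and retry.
    race_rejected = True
```

The constraint only detects the collision; serializing writers up front, as the advisory lock does on PostgreSQL, avoids the retry altogether.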
@@ -35,7 +35,7 @@ def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
     if app_config is None:
         try:
             app_config = get_app_config()
-        except FileNotFoundError:
+        except (FileNotFoundError, ValueError):
             return False
     return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
@@ -0,0 +1,35 @@
[project]
name = "deerflow-storage"
version = "0.1.0"
description = "DeerFlow storage framework"
requires-python = ">=3.12"
dependencies = [
    "dotenv>=0.9.9",
    "pydantic>=2.12.5",
    "pyyaml>=6.0.3",
    "sqlalchemy[asyncio]>=2.0,<3.0",
    "alembic>=1.13",
    "langgraph>=1.1.9",
]
[project.optional-dependencies]
postgres = [
    "asyncpg>=0.29",
    "langgraph-checkpoint-postgres>=3.0.5",
    "psycopg[binary]>=3.3.3",
    "psycopg-pool>=3.3.0",
]
mysql = [
    "aiomysql>=0.2",
    "langgraph-checkpoint-mysql>=3.0.0",
]
sqlite = [
    "aiosqlite>=0.22.1",
    "langgraph-checkpoint-sqlite>=3.0.3"
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["store"]
@@ -0,0 +1,5 @@
from .enums import DataBaseType

__all__ = [
    "DataBaseType",
]
@@ -0,0 +1,41 @@
from enum import Enum
from enum import IntEnum as SourceIntEnum
from enum import StrEnum as SourceStrEnum
from typing import Any, TypeVar

T = TypeVar("T", bound=Enum)


class _EnumBase:
    """Base enum class with common utility methods."""

    @classmethod
    def get_member_keys(cls) -> list[str]:
        """Return a list of enum member names."""
        return list(cls.__members__.keys())

    @classmethod
    def get_member_values(cls) -> list:
        """Return a list of enum member values."""
        return [item.value for item in cls.__members__.values()]

    @classmethod
    def get_member_dict(cls) -> dict[str, Any]:
        """Return a dict mapping member names to values."""
        return {name: item.value for name, item in cls.__members__.items()}


class IntEnum(_EnumBase, SourceIntEnum):
    """Integer enum base class."""


class StrEnum(_EnumBase, SourceStrEnum):
    """String enum base class."""


class DataBaseType(StrEnum):
    """Database type."""

    sqlite = "sqlite"
    mysql = "mysql"
    postgresql = "postgresql"
@@ -0,0 +1,286 @@
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self

import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field

from store.config.storage_config import StorageConfig

load_dotenv()

logger = logging.getLogger(__name__)


def _default_config_candidates() -> tuple[Path, ...]:
    """Return config.yaml candidates: cwd first, then the backend dir and repo root resolved from this file."""
    backend_dir = Path(__file__).resolve().parents[4]
    repo_root = backend_dir.parent
    cwd = Path.cwd().resolve()
    candidates = (
        cwd / "config.yaml",
        backend_dir / "config.yaml",
        repo_root / "config.yaml",
    )
    # dict.fromkeys drops duplicate paths while keeping the priority order.
    return tuple(dict.fromkeys(candidates))


def _storage_from_database_config(config_data: dict[str, Any]) -> None:
    """Keep the existing public `database:` config compatible with storage."""
    if "storage" in config_data:
        return

    database = config_data.get("database")
    if not isinstance(database, dict):
        return

    backend = database.get("backend")
    if backend == "memory":
        raise ValueError("database.backend='memory' is not supported by storage; handle memory mode before loading storage config")

    storage: dict[str, Any] = {
        "driver": "postgres" if backend == "postgres" else backend,
        "sqlite_dir": database.get("sqlite_dir", ".deer-flow/data"),
        "echo_sql": database.get("echo_sql", False),
        "pool_size": database.get("pool_size", 5),
    }

    postgres_url = database.get("postgres_url")
    if backend == "postgres" and isinstance(postgres_url, str) and postgres_url:
        from sqlalchemy.engine.url import make_url

        parsed = make_url(postgres_url)
        storage["database_url"] = postgres_url
        storage.update(
            {
                "username": parsed.username or "",
                "password": parsed.password or "",
                "host": parsed.host or "localhost",
                "port": parsed.port or 5432,
                "db_name": parsed.database or "deerflow",
            }
        )

    config_data["storage"] = storage


class AppConfig(BaseModel):
    """DeerFlow application configuration."""

    timezone: str = Field(default="UTC", description="Timezone for scheduling and timestamps (e.g. 'UTC', 'America/New_York')")
    log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
    storage: StorageConfig = Field(default=StorageConfig())
    model_config = ConfigDict(extra="allow", frozen=False)

    @classmethod
    def resolve_config_path(cls, config_path: str | None = None) -> Path:
        """Resolve the config file path.

        Priority:
        1. If the `config_path` argument is provided, use it.
        2. If the `DEER_FLOW_CONFIG_PATH` environment variable is set, use it.
        3. Otherwise, return the first existing candidate from `_default_config_candidates()`
           (cwd, backend directory, repository root).
        """
        if config_path:
            path = Path(config_path)
            if not path.exists():
                raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
            return path
        elif os.getenv("DEER_FLOW_CONFIG_PATH"):
            path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
            if not path.exists():
                raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
            return path
        else:
            for path in _default_config_candidates():
                if path.exists():
                    return path
            raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")

    @classmethod
    def from_file(cls, config_path: str | None = None) -> Self:
        """Load and validate config from YAML. See `resolve_config_path` for path resolution."""
        resolved_path = cls.resolve_config_path(config_path)
        with open(resolved_path, encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}

        cls._check_config_version(config_data, resolved_path)

        config_data = cls.resolve_env_variables(config_data)
        _storage_from_database_config(config_data)

        if os.getenv("TIMEZONE"):
            config_data["timezone"] = os.getenv("TIMEZONE")

        result = cls.model_validate(config_data)
        return result

    @classmethod
    def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
        """Check if the user's config.yaml is outdated compared to config.example.yaml.

        Emits a warning if the user's config_version is lower than the example's.
        Missing config_version is treated as version 0 (pre-versioning).
        """
        try:
            user_version = int(config_data.get("config_version", 0))
        except (TypeError, ValueError):
            user_version = 0

        # Find config.example.yaml by searching config.yaml's directory and its parents
        example_path = None
        search_dir = config_path.parent
        for _ in range(5):  # search up to 5 levels
            candidate = search_dir / "config.example.yaml"
            if candidate.exists():
                example_path = candidate
                break
            parent = search_dir.parent
            if parent == search_dir:
                break
            search_dir = parent
        if example_path is None:
            return

        try:
            with open(example_path, encoding="utf-8") as f:
                example_data = yaml.safe_load(f)
            raw = example_data.get("config_version", 0) if example_data else 0
            try:
                example_version = int(raw)
            except (TypeError, ValueError):
                example_version = 0
        except Exception:
            return

        if user_version < example_version:
            logger.warning(
                "Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to merge new fields into your config.",
                user_version,
                example_version,
            )

    @classmethod
    def resolve_env_variables(cls, config: Any) -> Any:
        """Recursively replace $VAR strings with their environment variable values (e.g. $OPENAI_API_KEY)."""
        if isinstance(config, str):
            if config.startswith("$"):
                env_value = os.getenv(config[1:])
                if env_value is None:
                    raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
                return env_value
            return config
        elif isinstance(config, dict):
            return {k: cls.resolve_env_variables(v) for k, v in config.items()}
        elif isinstance(config, list):
            return [cls.resolve_env_variables(item) for item in config]
        return config


_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())


def _get_config_mtime(config_path: Path) -> float | None:
    """Get the modification time of a config file if it exists."""
    try:
        return config_path.stat().st_mtime
    except OSError:
        return None


def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
    """Load config from disk and refresh cache metadata."""
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom

    resolved_path = AppConfig.resolve_config_path(config_path)
    _app_config = AppConfig.from_file(str(resolved_path))
    _app_config_path = resolved_path
    _app_config_mtime = _get_config_mtime(resolved_path)
    _app_config_is_custom = False
    return _app_config


def get_app_config() -> AppConfig:
    """Get the DeerFlow config instance.

    Returns a cached singleton instance and automatically reloads it when the
    underlying config file path or modification time changes. Use
    `reload_app_config()` to force a reload, or `reset_app_config()` to clear
    the cache.
    """
    global _app_config, _app_config_path, _app_config_mtime

    runtime_override = _current_app_config.get()
    if runtime_override is not None:
        return runtime_override

    if _app_config is not None and _app_config_is_custom:
        return _app_config

    resolved_path = AppConfig.resolve_config_path()
    current_mtime = _get_config_mtime(resolved_path)

    should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
    if should_reload:
        if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
            logger.info(
                "Config file has been modified (mtime: %s -> %s), reloading AppConfig",
                _app_config_mtime,
                current_mtime,
            )
        _load_and_cache_app_config(str(resolved_path))
    return _app_config


def reload_app_config(config_path: str | None = None) -> AppConfig:
    """Force reload from file and update the cache."""
    return _load_and_cache_app_config(config_path)


def reset_app_config() -> None:
    """Clear the cache so the next `get_app_config()` reloads from file."""
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
    _app_config = None
    _app_config_path = None
    _app_config_mtime = None
    _app_config_is_custom = False


def set_app_config(config: AppConfig) -> None:
    """Inject a config instance directly, bypassing file loading (for testing)."""
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
    _app_config = config
    _app_config_path = None
    _app_config_mtime = None
    _app_config_is_custom = True


def peek_current_app_config() -> AppConfig | None:
    """Return the runtime-scoped AppConfig override, if one is active."""
    return _current_app_config.get()


def push_current_app_config(config: AppConfig) -> None:
    """Push a runtime-scoped AppConfig override for the current execution context."""
    stack = _current_app_config_stack.get()
    _current_app_config_stack.set(stack + (_current_app_config.get(),))
    _current_app_config.set(config)


def pop_current_app_config() -> None:
    """Pop the latest runtime-scoped AppConfig override for the current execution context."""
    stack = _current_app_config_stack.get()
    if not stack:
        _current_app_config.set(None)
        return
    previous = stack[-1]
    _current_app_config_stack.set(stack[:-1])
    _current_app_config.set(previous)
@@ -0,0 +1,69 @@
"""Unified storage backend configuration for checkpointer and application data.

SQLite: checkpointer and app data share {sqlite_dir}/deerflow.db
    (see `sqlite_storage_path`).
Postgres / MySQL: shared URL, independent connection pools per layer.

Sensitive values use $VAR syntax resolved by AppConfig.resolve_env_variables()
before this config is instantiated.
"""

from __future__ import annotations

import os
from typing import Literal

from pydantic import BaseModel, Field


def _strip_legacy_state_prefix(path: str) -> str:
    """Keep old .deer-flow/* config values compatible with Paths.base_dir."""
    prefix = ".deer-flow/"
    if path == ".deer-flow":
        return "."
    if path.startswith(prefix):
        return path[len(prefix) :]
    return path


class StorageConfig(BaseModel):
    driver: Literal["mysql", "sqlite", "postgres", "postgresql"] = Field(
        default="sqlite",
        description="Storage driver for both checkpointer and application data. 'sqlite' for single-node deployment (default), 'postgres' for production multi-node deployment, 'mysql' for MySQL databases.",
    )
    sqlite_dir: str = Field(
        default=".deer-flow/data",
        description="Directory for SQLite .db files (sqlite driver only).",
    )
    username: str = Field(default="", description="db username.")
    password: str = Field(default="", description="db password. Use $VAR syntax in config.yaml to read from .env.")
    host: str = Field(default="localhost", description="db host.")
    port: int = Field(default=5432, description="db port.")
    db_name: str = Field(default="deerflow", description="db database name.")
    database_url: str = Field(default="", description="Complete SQLAlchemy database URL. Takes precedence for non-SQLite drivers.")
    sqlite_db_path: str = Field(default=".deer-flow/data", description="Directory for SQLite .db files (sqlite driver only).")
    echo_sql: bool = Field(default=False, description="Log all SQL statements (debug only).")
    pool_size: int = Field(default=5, description="Connection pool size per layer.")

    # -- Derived helpers (not user-configured) --

    @property
    def _resolved_sqlite_dir(self) -> str:
        """Resolve sqlite_dir to an absolute path under DeerFlow's base dir."""
        from pathlib import Path

        path = Path(self.sqlite_dir)
        if path.is_absolute():
            return str(path.resolve())

        try:
            from deerflow.config.paths import resolve_path

            return str(resolve_path(_strip_legacy_state_prefix(self.sqlite_dir)))
        except ImportError:
            return str(path.resolve())

    @property
    def sqlite_storage_path(self) -> str:
        """SQLite file path for storage-owned app data and checkpointer."""
        return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
@@ -0,0 +1,32 @@
from store.persistence.base_model import (
    Base,
    DataClassBase,
    DateTimeMixin,
    MappedBase,
    TimeZone,
    UniversalText,
    id_key,
)

from .factory import (
    create_persistence,
    create_persistence_from_database_config,
    create_persistence_from_storage_config,
    storage_config_from_database_config,
)
from .types import AppPersistence

__all__ = [
    "Base",
    "DataClassBase",
    "DateTimeMixin",
    "MappedBase",
    "TimeZone",
    "UniversalText",
    "id_key",
    "create_persistence",
    "create_persistence_from_database_config",
    "create_persistence_from_storage_config",
    "storage_config_from_database_config",
    "AppPersistence",
]
@@ -0,0 +1,111 @@
from datetime import datetime
from typing import Annotated

from sqlalchemy import BigInteger, DateTime, Integer, Text, TypeDecorator
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column

from store.utils import get_timezone


def current_time() -> datetime:
    return get_timezone().now()


id_key = Annotated[
    int,
    mapped_column(
        BigInteger().with_variant(Integer, "sqlite"),
        primary_key=True,
        unique=True,
        index=True,
        autoincrement=True,
        sort_order=-999,
        comment="Primary key ID",
    ),
]


class UniversalText(TypeDecorator[str]):
    """Cross-dialect long text type (LONGTEXT on MySQL, Text on every other dialect)."""

    impl = Text
    cache_ok = True

    def load_dialect_impl(self, dialect):  # noqa: ANN001
        if dialect.name == "mysql":
            return dialect.type_descriptor(LONGTEXT())
        return dialect.type_descriptor(Text())

    def process_bind_param(self, value: str | None, dialect) -> str | None:  # noqa: ANN001
        return value

    def process_result_value(self, value: str | None, dialect) -> str | None:  # noqa: ANN001
        return value


class TimeZone(TypeDecorator[datetime]):
    """Timezone-aware datetime type compatible with PostgreSQL and MySQL."""

    impl = DateTime(timezone=True)
    cache_ok = True

    @property
    def python_type(self) -> type[datetime]:
        return datetime

    def process_bind_param(self, value: datetime | None, dialect) -> datetime | None:  # noqa: ANN001
        timezone = get_timezone()
        if value is not None and value.utcoffset() != timezone.now().utcoffset():
            value = timezone.from_datetime(value)
        return value

    def process_result_value(self, value: datetime | None, dialect) -> datetime | None:  # noqa: ANN001
        timezone = get_timezone()
        if value is not None and value.tzinfo is None:
            value = value.replace(tzinfo=timezone.tz_info)
        return value


class DateTimeMixin(MappedAsDataclass):
    """Mixin that adds created_time / updated_time columns."""

    created_time: Mapped[datetime] = mapped_column(
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Created at",
    )
    updated_time: Mapped[datetime | None] = mapped_column(
        TimeZone,
        init=False,
        onupdate=current_time,
        sort_order=999,
        comment="Updated at",
    )


class MappedBase(AsyncAttrs, DeclarativeBase):
    """Async-capable declarative base for all ORM models."""

    @declared_attr.directive
    def __tablename__(self) -> str:
        return self.__name__.lower()

    @declared_attr.directive
    def __table_args__(self) -> dict:
        return {"comment": self.__doc__ or ""}


class DataClassBase(MappedAsDataclass, MappedBase):
    """Declarative base with native dataclass integration."""

    __abstract__ = True


class Base(DataClassBase, DateTimeMixin):
    """Declarative dataclass base with created_time / updated_time columns."""

    __abstract__ = True
@@ -0,0 +1,9 @@
from .mysql import build_mysql_persistence
from .postgres import build_postgres_persistence
from .sqlite import build_sqlite_persistence

__all__ = [
    "build_postgres_persistence",
    "build_mysql_persistence",
    "build_sqlite_persistence",
]
@@ -0,0 +1,76 @@
from __future__ import annotations

import json

from sqlalchemy import URL
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence


def _validate_mysql_driver(db_url: URL) -> str:
    url = make_url(db_url)
    driver = url.get_driver_name()

    if driver not in {"aiomysql", "asyncmy"}:
        raise ValueError(f"MySQL persistence requires async SQLAlchemy driver (aiomysql/asyncmy), got: {driver!r}")
    return driver


def _checkpoint_conn_string(db_url: URL) -> str:
    return db_url.render_as_string(hide_password=False)


async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
    _validate_mysql_driver(db_url)

    from langgraph.checkpoint.mysql.aio import AIOMySQLSaver

    import store.repositories.models  # noqa: F401

    engine = create_async_engine(
        db_url,
        echo=echo,
        future=True,
        pool_pre_ping=True,
        pool_size=pool_size,
        json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
    )

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(db_url))
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables / migrations
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
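Review note: the builder enters the saver's context manager by calling `__aenter__` directly and pairs it with a manual `__aexit__` in `aclose`. An alternative worth considering is `contextlib.AsyncExitStack`, which keeps the same close order and also cleans up if construction fails halfway. The sketch below is not the PR's implementation: `fake_saver` and `fake_engine_dispose` are hypothetical stand-ins for the real saver and `engine.dispose`.

```python
import asyncio
from contextlib import AsyncExitStack, asynccontextmanager
from typing import Any, AsyncIterator, Awaitable, Callable

events: list[str] = []


@asynccontextmanager
async def fake_saver(conn_string: str) -> AsyncIterator[dict[str, str]]:
    """Stand-in for a checkpointer created via from_conn_string()."""
    events.append("saver opened")
    try:
        yield {"conn": conn_string}
    finally:
        events.append("saver closed")


async def fake_engine_dispose() -> None:
    """Stand-in for engine.dispose()."""
    events.append("engine disposed")


async def build() -> tuple[Any, Callable[[], Awaitable[None]]]:
    stack = AsyncExitStack()
    try:
        checkpointer = await stack.enter_async_context(fake_saver("mysql://example"))
        # Registered last, so it runs first on close (LIFO): engine, then saver.
        stack.push_async_callback(fake_engine_dispose)
    except BaseException:
        await stack.aclose()  # undo whatever was already opened
        raise
    return checkpointer, stack.aclose


async def main() -> Any:
    checkpointer, aclose = await build()
    events.append("in use")
    await aclose()
    return checkpointer


result = asyncio.run(main())
```

The resulting order (engine disposed, then saver closed) matches the argument order passed to `close_in_order` in the diff.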
@@ -0,0 +1,64 @@
from __future__ import annotations

import json

from sqlalchemy import URL
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence


def _checkpoint_conn_string(db_url: URL) -> str:
    return db_url.set(drivername="postgresql").render_as_string(hide_password=False)


async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
    from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

    import store.repositories.models  # noqa: F401

    engine = create_async_engine(
        db_url,
        echo=echo,
        future=True,
        pool_pre_ping=True,
        pool_size=pool_size,
        json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
    )

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    saver_cm = AsyncPostgresSaver.from_conn_string(_checkpoint_conn_string(db_url))
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables / migrations
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
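Review note: `_checkpoint_conn_string` resets the driver name to plain `postgresql` because the checkpointer connects through a libpq-style URI, while SQLAlchemy's async engine needs the `+asyncpg` suffix. The sketch below approximates that transformation with string handling only. It does not use `URL.set`, and unlike the diff it keeps whatever dialect name precedes the `+` instead of forcing `postgresql`; `strip_driver_suffix` is a hypothetical helper.

```python
def strip_driver_suffix(url: str) -> str:
    """Turn 'dialect+driver://...' into 'dialect://...'; plain URLs pass through."""
    scheme, sep, rest = url.partition("://")
    if not sep:
        raise ValueError(f"Not a URL with a scheme: {url!r}")
    dialect = scheme.split("+", 1)[0]
    return f"{dialect}://{rest}"


engine_url = "postgresql+asyncpg://deer:secret@db.internal:5432/deerflow"
checkpoint_url = strip_driver_suffix(engine_url)
already_plain = strip_driver_suffix("postgresql://deer:secret@db.internal:5432/deerflow")
```

For real URLs the SQLAlchemy call in the diff is preferable, since it also handles quoting of special characters in the password.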
@@ -0,0 +1,68 @@
from __future__ import annotations

import json

from sqlalchemy import URL, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence


async def build_sqlite_persistence(db_url: URL, *, echo: bool = False) -> AppPersistence:
    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    import store.repositories.models  # noqa: F401

    engine = create_async_engine(
        db_url,
        echo=echo,
        future=True,
        json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
    )

    @event.listens_for(engine.sync_engine, "connect")
    def _enable_sqlite_pragmas(dbapi_conn, _record):  # noqa: ANN001
        cursor = dbapi_conn.cursor()
        try:
            cursor.execute("PRAGMA journal_mode=WAL;")
            cursor.execute("PRAGMA synchronous=NORMAL;")
            cursor.execute("PRAGMA foreign_keys=ON;")
        finally:
            cursor.close()

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    saver_cm = AsyncSqliteSaver.from_conn_string(db_url.database)
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
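Review note: the `connect` listener exists because `foreign_keys` and `synchronous` are per-connection settings in SQLite, so every pooled connection has to set them again. The stdlib sketch below shows the same three PRAGMAs applied when a connection is opened. `open_connection` is a hypothetical helper and the temporary database file is only for the demonstration.

```python
import os
import sqlite3
import tempfile


def open_connection(path: str) -> sqlite3.Connection:
    """Open a connection and apply the PRAGMAs used by the SQLite driver in the diff."""
    conn = sqlite3.connect(path)
    conn.execute("PRAGMA journal_mode=WAL;")    # stored in the database file
    conn.execute("PRAGMA synchronous=NORMAL;")  # per connection
    conn.execute("PRAGMA foreign_keys=ON;")     # per connection
    return conn


with tempfile.TemporaryDirectory() as tmp:
    db_path = os.path.join(tmp, "deerflow.db")
    configured = open_connection(db_path)
    foreign_keys = configured.execute("PRAGMA foreign_keys;").fetchone()[0]
    synchronous = configured.execute("PRAGMA synchronous;").fetchone()[0]
    journal_mode = configured.execute("PRAGMA journal_mode;").fetchone()[0]
    configured.close()
```

`PRAGMA synchronous` reports the level as an integer, where `1` corresponds to `NORMAL`.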
@@ -0,0 +1,123 @@
from typing import Any

from sqlalchemy import URL
from sqlalchemy.engine.url import make_url

from store.common import DataBaseType
from store.config.app_config import get_app_config
from store.config.storage_config import StorageConfig
from store.persistence.types import AppPersistence


def storage_config_from_database_config(database_config: Any) -> StorageConfig:
    """Convert the existing public DatabaseConfig shape to StorageConfig.

    Storage only owns durable database-backed persistence. The app bridge
    should handle memory mode before calling into this package.
    """
    backend = getattr(database_config, "backend", None)
    if backend == "sqlite":
        return StorageConfig(
            driver="sqlite",
            sqlite_dir=getattr(database_config, "sqlite_dir", ".deer-flow/data"),
            echo_sql=getattr(database_config, "echo_sql", False),
            pool_size=getattr(database_config, "pool_size", 5),
        )

    if backend == "postgres":
        postgres_url = getattr(database_config, "postgres_url", "")
        if not postgres_url:
            raise ValueError("database.postgres_url is required when database.backend is 'postgres'")
        parsed = make_url(postgres_url)
        return StorageConfig(
            driver="postgres",
            database_url=postgres_url,
            username=parsed.username or "",
            password=parsed.password or "",
            host=parsed.host or "localhost",
            port=parsed.port or 5432,
            db_name=parsed.database or "deerflow",
            echo_sql=getattr(database_config, "echo_sql", False),
            pool_size=getattr(database_config, "pool_size", 5),
        )

    raise ValueError(f"Unsupported database backend for storage persistence: {backend!r}")


def _create_database_url(storage_config: StorageConfig) -> URL:
    """Build an async SQLAlchemy URL from StorageConfig (sqlite/mysql/postgres)."""

    if storage_config.driver == DataBaseType.sqlite:
        driver = "sqlite+aiosqlite"
    elif storage_config.driver == DataBaseType.mysql:
        driver = "mysql+aiomysql"
    elif storage_config.driver in (DataBaseType.postgresql, "postgres"):
        driver = "postgresql+asyncpg"
    else:
        raise ValueError(f"Unsupported database driver: {storage_config.driver}")

    if storage_config.driver == DataBaseType.sqlite:
        import os

        db_path = storage_config.sqlite_storage_path
        os.makedirs(os.path.dirname(db_path), exist_ok=True)

        url = URL.create(
            drivername=driver,
            database=db_path,
        )
    elif storage_config.database_url:
        url = make_url(storage_config.database_url)
        if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
            url = url.set(drivername="postgresql+asyncpg")
        elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
            url = url.set(drivername="mysql+aiomysql")
    else:
        url = URL.create(
            drivername=driver,
            username=storage_config.username,
            password=storage_config.password,
            host=storage_config.host,
            port=storage_config.port,
            database=storage_config.db_name or "deerflow",
        )

    return url


async def create_persistence_from_storage_config(storage_config: StorageConfig) -> AppPersistence:
    from .drivers.mysql import build_mysql_persistence
    from .drivers.postgres import build_postgres_persistence
    from .drivers.sqlite import build_sqlite_persistence

    driver = storage_config.driver
    db_url = _create_database_url(storage_config)

    if driver in ("postgres", "postgresql"):
        return await build_postgres_persistence(
            db_url,
            echo=storage_config.echo_sql,
            pool_size=storage_config.pool_size,
        )

    if driver == "mysql":
        return await build_mysql_persistence(
            db_url,
            echo=storage_config.echo_sql,
            pool_size=storage_config.pool_size,
        )

    if driver == "sqlite":
        return await build_sqlite_persistence(db_url, echo=storage_config.echo_sql)

    raise ValueError(f"Unsupported database driver: {driver}")


async def create_persistence_from_database_config(database_config: Any) -> AppPersistence:
    storage_config = storage_config_from_database_config(database_config)
    return await create_persistence_from_storage_config(storage_config)


async def create_persistence() -> AppPersistence:
    app_config = get_app_config()
    return await create_persistence_from_storage_config(app_config.storage)
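The driver upgrade performed in `_create_database_url` (a bare `postgresql` or `mysql` URL is switched to its async driver, while an explicitly chosen driver is left alone) can be illustrated without SQLAlchemy. The sketch below is a simplified stand-in that works on plain URL strings; `ASYNC_DRIVERS` and `to_async_url` are hypothetical names that do not exist in the package, and the real code operates on `sqlalchemy.URL` objects instead.

```python
# Hypothetical illustration only: the package itself uses sqlalchemy.URL.
# Maps a bare sync scheme to the async driver named in the code above.
ASYNC_DRIVERS = {
    "postgresql": "postgresql+asyncpg",
    "mysql": "mysql+aiomysql",
}


def to_async_url(url: str) -> str:
    """Rewrite the scheme of *url* only when it names a bare sync driver."""
    scheme, sep, rest = url.partition("://")
    if not sep:
        raise ValueError(f"Not a database URL: {url!r}")
    # An explicit driver such as "postgresql+psycopg" is not in the mapping,
    # so it passes through unchanged, like the drivername checks above.
    return f"{ASYNC_DRIVERS.get(scheme, scheme)}://{rest}"


upgraded = to_async_url("postgresql://user:secret@localhost:5432/deerflow")
untouched = to_async_url("postgresql+psycopg://localhost/deerflow")
```

Keeping the check on the exact bare scheme means a deployment that already pins a driver in its URL is not silently overridden.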
@@ -0,0 +1,189 @@
"""Dialect-aware JSON value matching for storage SQLAlchemy repositories."""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any

from sqlalchemy import BigInteger, Float, String, bindparam
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.types import Boolean, TypeEngine

_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)

_INT64_MIN = -(2**63)
_INT64_MAX = 2**63 - 1


def validate_metadata_filter_key(key: object) -> bool:
    """Return True when *key* is safe for JSON metadata filter SQL paths."""
    return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key))


def validate_metadata_filter_value(value: object) -> bool:
    """Return True when *value* can be compiled into a portable JSON predicate."""
    if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
        return False
    if isinstance(value, int) and not isinstance(value, bool):
        return _INT64_MIN <= value <= _INT64_MAX
    return True


class JsonMatch(ColumnElement[bool]):
    """Dialect-portable ``column[key] == value`` for JSON columns."""

    inherit_cache = True
    type = Boolean()
    _is_implicitly_boolean = True

    _traverse_internals = [
        ("column", InternalTraversal.dp_clauseelement),
        ("key", InternalTraversal.dp_string),
        ("value", InternalTraversal.dp_plain_obj),
        ("value_type", InternalTraversal.dp_string),
    ]

    def __init__(self, column: ColumnElement[Any], key: str, value: object) -> None:
        if not validate_metadata_filter_key(key):
            raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")
        if not validate_metadata_filter_value(value):
            if isinstance(value, int) and not isinstance(value, bool):
                raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}")
            raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}")
        self.column = column
        self.key = key
        self.value = value
        self.value_type = type(value).__qualname__
        super().__init__()


@dataclass(frozen=True)
class _Dialect:
    null_type: str
    num_types: tuple[str, ...]
    num_cast: str
    int_types: tuple[str, ...]
    int_cast: str
    int_guard: str | None
    string_type: str
    bool_type: str | None
    true_value: str
    false_value: str


_SQLITE = _Dialect(
    null_type="null",
    num_types=("integer", "real"),
    num_cast="REAL",
    int_types=("integer",),
    int_cast="INTEGER",
    int_guard=None,
    string_type="text",
    bool_type=None,
    true_value="true",
    false_value="false",
)

_POSTGRES = _Dialect(
    null_type="null",
    num_types=("number",),
    num_cast="DOUBLE PRECISION",
    int_types=("number",),
    int_cast="BIGINT",
    int_guard="'^-?[0-9]+$'",
    string_type="string",
    bool_type="boolean",
    true_value="true",
    false_value="false",
)

_MYSQL = _Dialect(
    null_type="NULL",
    num_types=("INTEGER", "DOUBLE", "DECIMAL"),
    num_cast="DOUBLE",
    int_types=("INTEGER",),
    int_cast="SIGNED",
    int_guard=None,
    string_type="STRING",
    bool_type="BOOLEAN",
    true_value="true",
    false_value="false",
)


def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
    param = bindparam(None, value, type_=sa_type)
    return compiler.process(param, **kw)


def _type_check(typeof: str, types: tuple[str, ...]) -> str:
    if len(types) == 1:
        return f"{typeof} = '{types[0]}'"
    quoted = ", ".join(f"'{type_name}'" for type_name in types)
    return f"{typeof} IN ({quoted})"


def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
    if value is None:
        return f"{typeof} = '{dialect.null_type}'"
    if isinstance(value, bool):
        bool_str = dialect.true_value if value else dialect.false_value
        if dialect.bool_type is None:
            return f"{typeof} = '{bool_str}'"
        return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')"
    if isinstance(value, int):
        bp = _bind(compiler, value, BigInteger(), **kw)
        if dialect.int_guard:
            return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
        return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})"
    if isinstance(value, float):
        bp = _bind(compiler, value, Float(), **kw)
        return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})"
    bp = _bind(compiler, str(value), String(), **kw)
    return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})"


@compiles(JsonMatch, "sqlite")
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    if not validate_metadata_filter_key(element.key):
        raise ValueError(f"Key escaped validation: {element.key!r}")
    col = compiler.process(element.column, **kw)
    path = f'$."{element.key}"'
    typeof = f"json_type({col}, '{path}')"
    extract = f"json_extract({col}, '{path}')"
    return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw)


@compiles(JsonMatch, "postgresql")
def _compile_postgres(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    if not validate_metadata_filter_key(element.key):
        raise ValueError(f"Key escaped validation: {element.key!r}")
    col = compiler.process(element.column, **kw)
    typeof = f"json_typeof({col} -> '{element.key}')"
    extract = f"({col} ->> '{element.key}')"
    return _build_clause(compiler, typeof, extract, element.value, _POSTGRES, **kw)


@compiles(JsonMatch, "mysql")
def _compile_mysql(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    if not validate_metadata_filter_key(element.key):
        raise ValueError(f"Key escaped validation: {element.key!r}")
    col = compiler.process(element.column, **kw)
    path = f'$."{element.key}"'
    typeof = f"JSON_TYPE(JSON_EXTRACT({col}, '{path}'))"
    extract = f"JSON_UNQUOTE(JSON_EXTRACT({col}, '{path}'))"
    return _build_clause(compiler, typeof, extract, element.value, _MYSQL, **kw)


@compiles(JsonMatch)
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    raise NotImplementedError(f"JsonMatch supports sqlite, postgresql, and mysql; got dialect: {compiler.dialect.name}")


def json_match(column: ColumnElement[Any], key: str, value: object) -> JsonMatch:
    return JsonMatch(column, key, value)
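The reason each predicate pairs a type check with the value comparison is easiest to see on SQLite, where a bare `CAST(... AS INTEGER) = 1` would also match the JSON string `"1"`. The following is a minimal sketch using only the standard-library `sqlite3` module (it assumes the bundled SQLite has the JSON functions enabled, as most Python builds do). The `threads` table, the `meta` column, and `find_by_int` are hypothetical names for illustration; the SQL text mirrors the shape produced for the integer case above but is written by hand, not generated by `JsonMatch`.

```python
import json
import re
import sqlite3

# Same character set as _KEY_CHARSET_RE above.
KEY_RE = re.compile(r"^[A-Za-z0-9_\-]+$")


def find_by_int(conn: sqlite3.Connection, key: str, value: int) -> list[str]:
    """Return thread ids whose meta[key] is the JSON integer *value*."""
    if not KEY_RE.match(key):
        raise ValueError(f"unsafe key: {key!r}")
    # The path is inlined into the SQL text, which is only safe because the
    # key was validated against the restricted character set first.
    path = f'$."{key}"'
    sql = (
        "SELECT thread_id FROM threads "
        f"WHERE json_type(meta, '{path}') = 'integer' "
        f"AND CAST(json_extract(meta, '{path}') AS INTEGER) = ? "
        "ORDER BY thread_id"
    )
    return [row[0] for row in conn.execute(sql, (value,))]


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE threads (thread_id TEXT PRIMARY KEY, meta TEXT)")
conn.executemany(
    "INSERT INTO threads VALUES (?, ?)",
    [
        ("t1", json.dumps({"n": 1})),      # JSON integer: should match
        ("t2", json.dumps({"n": "1"})),    # JSON string: rejected by the type check
        ("t3", json.dumps({"other": 1})),  # key missing: json_type yields NULL
    ],
)
matches = find_by_int(conn, "n", 1)
```

The value itself is always passed as a bound parameter; only the validated key reaches the SQL string.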
@@ -0,0 +1,3 @@
from .close import close_in_order

__all__ = ["close_in_order"]
@@ -0,0 +1,28 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable

AsyncCloser = Callable[[], Awaitable[None]]


async def close_in_order(*closers: AsyncCloser) -> None:
    """
    Run async closers in order and raise the first error, if any.

    Notes
    -----
    - Used to keep driver-specific close logic readable.
    - We intentionally do not stop at first failure, so later resources
      still get a chance to close.
    """
    first_error: Exception | None = None

    for closer in closers:
        try:
            await closer()
        except Exception as exc:
            if first_error is None:
                first_error = exc

    if first_error is not None:
        raise first_error
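The "keep going, then raise the first error" behavior of `close_in_order` can be checked with two stand-in closers. In this sketch the function body is repeated so the example runs on its own; `close_engine`, `close_saver`, `calls`, and `main` are hypothetical names used only for the demonstration.

```python
import asyncio
from collections.abc import Awaitable, Callable

AsyncCloser = Callable[[], Awaitable[None]]


async def close_in_order(*closers: AsyncCloser) -> None:
    # Same logic as the function above, repeated so the example is standalone.
    first_error: Exception | None = None
    for closer in closers:
        try:
            await closer()
        except Exception as exc:
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error


calls: list[str] = []


async def close_engine() -> None:  # hypothetical closer that fails
    calls.append("engine")
    raise RuntimeError("engine dispose failed")


async def close_saver() -> None:  # hypothetical closer that succeeds
    calls.append("saver")


async def main() -> str:
    try:
        await close_in_order(close_engine, close_saver)
    except RuntimeError as exc:
        return str(exc)
    return "no error"


error_message = asyncio.run(main())
```

Even though the first closer fails, the second one still runs, and the caller sees the first failure once every closer has been attempted.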
@@ -0,0 +1,23 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable
from dataclasses import dataclass

from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

AsyncSetup = Callable[[], Awaitable[None]]
AsyncClose = Callable[[], Awaitable[None]]


@dataclass(slots=True)
class AppPersistence:
    """
    Unified runtime persistence bundle.
    """

    checkpointer: Checkpointer
    engine: AsyncEngine
    session_factory: async_sessionmaker[AsyncSession]
    setup: AsyncSetup
    aclose: AsyncClose
@@ -0,0 +1,53 @@
from store.repositories.contracts import (
    Feedback,
    FeedbackAggregate,
    FeedbackCreate,
    FeedbackRepositoryProtocol,
    InvalidMetadataFilterError,
    Run,
    RunCreate,
    RunEvent,
    RunEventCreate,
    RunEventRepositoryProtocol,
    RunRepositoryProtocol,
    ThreadMeta,
    ThreadMetaCreate,
    ThreadMetaRepositoryProtocol,
    User,
    UserCreate,
    UserNotFoundError,
    UserRepositoryProtocol,
)
from store.repositories.factory import (
    build_feedback_repository,
    build_run_event_repository,
    build_run_repository,
    build_thread_meta_repository,
    build_user_repository,
)

__all__ = [
    "Feedback",
    "FeedbackAggregate",
    "FeedbackCreate",
    "FeedbackRepositoryProtocol",
    "InvalidMetadataFilterError",
    "Run",
    "RunCreate",
    "RunEvent",
    "RunEventCreate",
    "RunEventRepositoryProtocol",
    "RunRepositoryProtocol",
    "ThreadMeta",
    "ThreadMetaCreate",
    "ThreadMetaRepositoryProtocol",
    "User",
    "UserCreate",
    "UserNotFoundError",
    "UserRepositoryProtocol",
    "build_run_repository",
    "build_run_event_repository",
    "build_thread_meta_repository",
    "build_feedback_repository",
    "build_user_repository",
]
@@ -0,0 +1,49 @@
from store.repositories.contracts.feedback import (
    Feedback,
    FeedbackAggregate,
    FeedbackCreate,
    FeedbackRepositoryProtocol,
)
from store.repositories.contracts.run import (
    Run,
    RunCreate,
    RunRepositoryProtocol,
)
from store.repositories.contracts.run_event import (
    RunEvent,
    RunEventCreate,
    RunEventRepositoryProtocol,
)
from store.repositories.contracts.thread_meta import (
    InvalidMetadataFilterError,
    ThreadMeta,
    ThreadMetaCreate,
    ThreadMetaRepositoryProtocol,
)
from store.repositories.contracts.user import (
    User,
    UserCreate,
    UserNotFoundError,
    UserRepositoryProtocol,
)

__all__ = [
    "Feedback",
    "FeedbackAggregate",
    "FeedbackCreate",
    "FeedbackRepositoryProtocol",
    "Run",
    "RunCreate",
    "RunEvent",
    "RunEventCreate",
    "RunEventRepositoryProtocol",
    "RunRepositoryProtocol",
    "InvalidMetadataFilterError",
    "ThreadMeta",
    "ThreadMetaCreate",
    "ThreadMetaRepositoryProtocol",
    "User",
    "UserCreate",
    "UserNotFoundError",
    "UserRepositoryProtocol",
]
@@ -0,0 +1,77 @@
from __future__ import annotations

from datetime import datetime
from typing import Protocol, TypedDict

from pydantic import BaseModel, ConfigDict


class FeedbackCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    feedback_id: str
    run_id: str
    thread_id: str
    rating: int
    user_id: str | None = None
    message_id: str | None = None
    comment: str | None = None


class Feedback(BaseModel):
    model_config = ConfigDict(frozen=True)

    feedback_id: str
    run_id: str
    thread_id: str
    rating: int
    user_id: str | None
    message_id: str | None
    comment: str | None
    created_time: datetime


class FeedbackAggregate(TypedDict):
    run_id: str
    total: int
    positive: int
    negative: int


class FeedbackRepositoryProtocol(Protocol):
    async def create_feedback(self, data: FeedbackCreate) -> Feedback:
        pass

    async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
        pass

    async def get_feedback(self, feedback_id: str) -> Feedback | None:
        pass

    async def list_feedback_by_run(
        self,
        run_id: str,
        *,
        thread_id: str | None = None,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        pass

    async def list_feedback_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        pass

    async def delete_feedback(self, feedback_id: str) -> bool:
        pass

    async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
        pass

    async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
        pass
@@ -0,0 +1,100 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Protocol

from pydantic import BaseModel, ConfigDict, Field


class RunCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    run_id: str
    thread_id: str
    assistant_id: str | None = None
    user_id: str | None = None
    status: str = "pending"
    model_name: str | None = None
    multitask_strategy: str = "reject"
    error: str | None = None
    follow_up_to_run_id: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    kwargs: dict[str, Any] = Field(default_factory=dict)
    created_time: datetime | None = None


class Run(BaseModel):
    model_config = ConfigDict(frozen=True)

    run_id: str
    thread_id: str
    assistant_id: str | None
    user_id: str | None
    status: str
    model_name: str | None
    multitask_strategy: str
    error: str | None
    follow_up_to_run_id: str | None
    metadata: dict[str, Any]
    kwargs: dict[str, Any]
    total_input_tokens: int
    total_output_tokens: int
    total_tokens: int
    llm_call_count: int
    lead_agent_tokens: int
    subagent_tokens: int
    middleware_tokens: int
    message_count: int
    first_human_message: str | None
    last_ai_message: str | None
    created_time: datetime
    updated_time: datetime | None


class RunRepositoryProtocol(Protocol):
    async def create_run(self, data: RunCreate) -> Run:
        pass

    async def get_run(self, run_id: str) -> Run | None:
        pass

    async def list_runs_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[Run]:
        pass

    async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
        pass

    async def delete_run(self, run_id: str) -> None:
        pass

    async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
        pass

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        first_human_message: str | None = None,
        last_ai_message: str | None = None,
        error: str | None = None,
    ) -> None:
        pass

    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        pass
@@ -0,0 +1,83 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Protocol

from pydantic import BaseModel, ConfigDict, Field


class RunEventCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    thread_id: str
    run_id: str
    user_id: str | None = None
    event_type: str
    category: str
    content: Any = ""
    metadata: dict[str, Any] = Field(default_factory=dict)
    created_at: datetime | None = None


class RunEvent(BaseModel):
    model_config = ConfigDict(frozen=True)

    thread_id: str
    run_id: str
    user_id: str | None
    event_type: str
    category: str
    content: Any
    metadata: dict[str, Any]
    seq: int
    created_at: datetime


class RunEventRepositoryProtocol(Protocol):
    # Sequence values are time-ordered integer cursors. The application layer
    # owns the single-writer invariant for a thread while a run is active.
    async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
        pass

    async def list_messages(
        self,
        thread_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        pass

    async def list_events(
        self,
        thread_id: str,
        run_id: str,
        *,
        event_types: list[str] | None = None,
        limit: int = 500,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        pass

    async def list_messages_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        pass

    async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
        pass

    async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
        pass

    async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
        pass
@@ -0,0 +1,67 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Protocol

from pydantic import BaseModel, ConfigDict, Field


class InvalidMetadataFilterError(ValueError):
    """Raised when all client-supplied metadata filters are rejected."""


class ThreadMetaCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    thread_id: str
    assistant_id: str | None = None
    user_id: str | None = None
    display_name: str | None = None
    status: str = "idle"
    metadata: dict[str, Any] = Field(default_factory=dict)


class ThreadMeta(BaseModel):
    model_config = ConfigDict(frozen=True)

    thread_id: str
    assistant_id: str | None
    user_id: str | None
    display_name: str | None
    status: str
    metadata: dict[str, Any]
    created_time: datetime
    updated_time: datetime | None


class ThreadMetaRepositoryProtocol(Protocol):
    async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
        pass

    async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
        pass

    async def update_thread_meta(
        self,
        thread_id: str,
        *,
        display_name: str | None = None,
        status: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        pass

    async def delete_thread(self, thread_id: str) -> None:
        pass

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None = None,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        pass
@@ -0,0 +1,64 @@
from __future__ import annotations

from datetime import datetime
from typing import Literal, Protocol

from pydantic import BaseModel, ConfigDict


class UserNotFoundError(LookupError):
    """Raised when an update targets a user row that no longer exists."""


class UserCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    id: str
    email: str
    password_hash: str | None = None
    system_role: Literal["admin", "user"] = "user"
    created_at: datetime | None = None
    oauth_provider: str | None = None
    oauth_id: str | None = None
    needs_setup: bool = False
    token_version: int = 0


class User(BaseModel):
    model_config = ConfigDict(frozen=True)

    id: str
    email: str
    password_hash: str | None
    system_role: Literal["admin", "user"]
    created_at: datetime
    oauth_provider: str | None
    oauth_id: str | None
    needs_setup: bool
    token_version: int


class UserRepositoryProtocol(Protocol):
    async def create_user(self, data: UserCreate) -> User:
        pass

    async def get_user_by_id(self, user_id: str) -> User | None:
        pass

    async def get_user_by_email(self, email: str) -> User | None:
        pass

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        pass

    async def get_first_admin(self) -> User | None:
        pass

    async def update_user(self, data: User) -> User:
        pass

    async def count_users(self) -> int:
        pass

    async def count_admin_users(self) -> int:
        pass
@@ -0,0 +1,13 @@
from store.repositories.db.feedback import DbFeedbackRepository
from store.repositories.db.run import DbRunRepository
from store.repositories.db.run_event import DbRunEventRepository
from store.repositories.db.thread_meta import DbThreadMetaRepository
from store.repositories.db.user import DbUserRepository

__all__ = [
    "DbFeedbackRepository",
    "DbRunRepository",
    "DbRunEventRepository",
    "DbThreadMetaRepository",
    "DbUserRepository",
]
@@ -0,0 +1,142 @@
from __future__ import annotations

from datetime import UTC, datetime

from sqlalchemy import case, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.feedback import Feedback, FeedbackAggregate, FeedbackCreate, FeedbackRepositoryProtocol
from store.repositories.models.feedback import Feedback as FeedbackModel


def _to_feedback(m: FeedbackModel) -> Feedback:
    return Feedback(
        feedback_id=m.feedback_id,
        run_id=m.run_id,
        thread_id=m.thread_id,
        rating=m.rating,
        user_id=m.user_id,
        message_id=m.message_id,
        comment=m.comment,
        created_time=m.created_time,
    )


class DbFeedbackRepository(FeedbackRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_feedback(self, data: FeedbackCreate) -> Feedback:
        if data.rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {data.rating}")
        model = FeedbackModel(
            feedback_id=data.feedback_id,
            run_id=data.run_id,
            thread_id=data.thread_id,
            rating=data.rating,
            user_id=data.user_id,
            message_id=data.message_id,
            comment=data.comment,
        )
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_feedback(model)

    async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
        if data.rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {data.rating}")

        result = await self._session.execute(
            select(FeedbackModel).where(
                FeedbackModel.thread_id == data.thread_id,
                FeedbackModel.run_id == data.run_id,
                FeedbackModel.user_id == data.user_id,
            )
        )
        model = result.scalar_one_or_none()
        if model is None:
            return await self.create_feedback(data)

        model.rating = data.rating
        model.message_id = data.message_id
        model.comment = data.comment
        model.created_time = datetime.now(UTC)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_feedback(model)

    async def get_feedback(self, feedback_id: str) -> Feedback | None:
        result = await self._session.execute(select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
        model = result.scalar_one_or_none()
        return _to_feedback(model) if model else None

    async def list_feedback_by_run(
        self,
        run_id: str,
        *,
        thread_id: str | None = None,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        stmt = select(FeedbackModel).where(FeedbackModel.run_id == run_id)
        if thread_id is not None:
            stmt = stmt.where(FeedbackModel.thread_id == thread_id)
        if user_id is not None:
            stmt = stmt.where(FeedbackModel.user_id == user_id)
        stmt = stmt.order_by(FeedbackModel.created_time.desc())
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self._session.execute(stmt)
        return [_to_feedback(m) for m in result.scalars().all()]

    async def list_feedback_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        stmt = select(FeedbackModel).where(FeedbackModel.thread_id == thread_id)
        if user_id is not None:
            stmt = stmt.where(FeedbackModel.user_id == user_id)
        stmt = stmt.order_by(FeedbackModel.created_time.desc())
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self._session.execute(stmt)
        return [_to_feedback(m) for m in result.scalars().all()]

    async def delete_feedback(self, feedback_id: str) -> bool:
        existing = await self.get_feedback(feedback_id)
        if existing is None:
            return False
        await self._session.execute(delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
        return True

    async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
        stmt = select(FeedbackModel).where(
            FeedbackModel.thread_id == thread_id,
            FeedbackModel.run_id == run_id,
        )
        if user_id is not None:
            stmt = stmt.where(FeedbackModel.user_id == user_id)
        result = await self._session.execute(stmt)
        model = result.scalar_one_or_none()
        if model is None:
            return False
        await self._session.delete(model)
        return True

    async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
        stmt = select(
            func.count().label("total"),
            func.coalesce(func.sum(case((FeedbackModel.rating == 1, 1), else_=0)), 0).label("positive"),
            func.coalesce(func.sum(case((FeedbackModel.rating == -1, 1), else_=0)), 0).label("negative"),
        ).where(FeedbackModel.thread_id == thread_id, FeedbackModel.run_id == run_id)
        row = (await self._session.execute(stmt)).one()
        return {
            "run_id": run_id,
            "total": int(row.total),
            "positive": int(row.positive),
            "negative": int(row.negative),
        }
@@ -0,0 +1,185 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.run import Run, RunCreate, RunRepositoryProtocol
from store.repositories.models.run import Run as RunModel


def _to_run(m: RunModel) -> Run:
    return Run(
        run_id=m.run_id,
        thread_id=m.thread_id,
        assistant_id=m.assistant_id,
        user_id=m.user_id,
        status=m.status,
        model_name=m.model_name,
        multitask_strategy=m.multitask_strategy,
        error=m.error,
        follow_up_to_run_id=m.follow_up_to_run_id,
        metadata=dict(m.meta or {}),
        kwargs=dict(m.kwargs or {}),
        total_input_tokens=m.total_input_tokens,
        total_output_tokens=m.total_output_tokens,
        total_tokens=m.total_tokens,
        llm_call_count=m.llm_call_count,
        lead_agent_tokens=m.lead_agent_tokens,
        subagent_tokens=m.subagent_tokens,
        middleware_tokens=m.middleware_tokens,
        message_count=m.message_count,
        first_human_message=m.first_human_message,
        last_ai_message=m.last_ai_message,
        created_time=m.created_time,
        updated_time=m.updated_time,
    )


class DbRunRepository(RunRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_run(self, data: RunCreate) -> Run:
        model = RunModel(
            run_id=data.run_id,
            thread_id=data.thread_id,
            assistant_id=data.assistant_id,
            user_id=data.user_id,
            status=data.status,
            model_name=data.model_name,
            multitask_strategy=data.multitask_strategy,
            error=data.error,
            follow_up_to_run_id=data.follow_up_to_run_id,
            meta=dict(data.metadata),
            kwargs=dict(data.kwargs),
        )
        if data.created_time is not None:
            model.created_time = data.created_time
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_run(model)

    async def get_run(self, run_id: str) -> Run | None:
        result = await self._session.execute(select(RunModel).where(RunModel.run_id == run_id))
        model = result.scalar_one_or_none()
        return _to_run(model) if model else None

    async def list_runs_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[Run]:
        stmt = select(RunModel).where(RunModel.thread_id == thread_id)
        if user_id is not None:
            stmt = stmt.where(RunModel.user_id == user_id)
        stmt = stmt.order_by(RunModel.created_time.desc()).limit(limit).offset(offset)
        result = await self._session.execute(stmt)
        return [_to_run(m) for m in result.scalars().all()]

    async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
        values: dict = {"status": status}
        if error is not None:
            values["error"] = error
        await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))

    async def delete_run(self, run_id: str) -> None:
        await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))

    async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
        if before is None:
            before_dt = datetime.now().astimezone()
        elif isinstance(before, datetime):
            before_dt = before
        else:
            before_dt = datetime.fromisoformat(before)

        result = await self._session.execute(select(RunModel).where(RunModel.status == "pending", RunModel.created_time <= before_dt).order_by(RunModel.created_time.asc()))
        return [_to_run(m) for m in result.scalars().all()]

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        first_human_message: str | None = None,
        last_ai_message: str | None = None,
        error: str | None = None,
    ) -> None:
        values = {
            "status": status,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_tokens": total_tokens,
            "llm_call_count": llm_call_count,
            "lead_agent_tokens": lead_agent_tokens,
            "subagent_tokens": subagent_tokens,
            "middleware_tokens": middleware_tokens,
            "message_count": message_count,
        }
        if first_human_message is not None:
            values["first_human_message"] = first_human_message[:2000]
        if last_ai_message is not None:
            values["last_ai_message"] = last_ai_message[:2000]
        if error is not None:
            values["error"] = error
        await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))

    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        completed = RunModel.status.in_(("success", "error"))
        model_expr = func.coalesce(RunModel.model_name, "unknown")
        stmt = (
            select(
                model_expr.label("model"),
                func.count().label("runs"),
                func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
                func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
                func.coalesce(func.sum(RunModel.total_output_tokens), 0).label("total_output_tokens"),
                func.coalesce(func.sum(RunModel.lead_agent_tokens), 0).label("lead_agent"),
                func.coalesce(func.sum(RunModel.subagent_tokens), 0).label("subagent"),
                func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
            )
            .where(RunModel.thread_id == thread_id, completed)
            .group_by(model_expr)
        )

        rows = (await self._session.execute(stmt)).all()
        total_tokens = total_input = total_output = total_runs = 0
        lead_agent = subagent = middleware = 0
        by_model: dict[str, dict] = {}
        for row in rows:
            by_model[row.model] = {"tokens": row.total_tokens, "runs": row.runs}
            total_tokens += row.total_tokens
            total_input += row.total_input_tokens
            total_output += row.total_output_tokens
            total_runs += row.runs
            lead_agent += row.lead_agent
            subagent += row.subagent
            middleware += row.middleware

        return {
            "total_tokens": total_tokens,
            "total_input_tokens": total_input,
            "total_output_tokens": total_output,
            "total_runs": total_runs,
            "by_model": by_model,
            "by_caller": {
                "lead_agent": lead_agent,
                "subagent": subagent,
                "middleware": middleware,
            },
        }
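The rollup that `aggregate_tokens_by_thread` performs on the grouped query rows can be illustrated without a database. This is a minimal sketch, assuming rows shaped like the labelled columns of the query above; `TokenRow` and `summarize_token_rows` are hypothetical names used only for illustration and are not part of the package.

```python
from typing import Any, NamedTuple


class TokenRow(NamedTuple):
    """Hypothetical stand-in for one grouped row returned by the SQL query."""

    model: str
    runs: int
    total_tokens: int
    total_input_tokens: int
    total_output_tokens: int
    lead_agent: int
    subagent: int
    middleware: int


def summarize_token_rows(rows: list[TokenRow]) -> dict[str, Any]:
    """Fold per-model rows into thread-level totals plus two breakdowns."""
    by_model: dict[str, dict[str, int]] = {}
    by_caller = {"lead_agent": 0, "subagent": 0, "middleware": 0}
    totals = {"total_tokens": 0, "total_input_tokens": 0, "total_output_tokens": 0, "total_runs": 0}
    for row in rows:
        # One entry per model name; the query groups by model, so names are unique.
        by_model[row.model] = {"tokens": row.total_tokens, "runs": row.runs}
        totals["total_tokens"] += row.total_tokens
        totals["total_input_tokens"] += row.total_input_tokens
        totals["total_output_tokens"] += row.total_output_tokens
        totals["total_runs"] += row.runs
        by_caller["lead_agent"] += row.lead_agent
        by_caller["subagent"] += row.subagent
        by_caller["middleware"] += row.middleware
    return {**totals, "by_model": by_model, "by_caller": by_caller}


# Illustrative data only; the model names are placeholders.
example_rows = [
    TokenRow("model-a", 2, 300, 200, 100, 250, 40, 10),
    TokenRow("unknown", 1, 50, 30, 20, 50, 0, 0),
]
example_summary = summarize_token_rows(example_rows)
```

An empty row list yields zero totals and an empty `by_model`, which matches what the repository method returns for a thread with no completed runs.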
@@ -0,0 +1,207 @@
from __future__ import annotations

import json
import secrets
import threading
import time
from typing import Any

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
from store.repositories.models.run_event import RunEvent as RunEventModel

_SEQ_COUNTER_BITS = 12
_SEQ_PROCESS_BITS = 9
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS


class _SequenceAllocator:
    def __init__(self) -> None:
        self._last_millis = 0
        self._lock = threading.Lock()

    def allocate_base(self, batch_size: int) -> int:
        if batch_size >= _SEQ_COUNTER_LIMIT:
            raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")

        now_ms = time.time_ns() // 1_000_000
        with self._lock:
            seq_ms = max(now_ms, self._last_millis + 1)
            self._last_millis = seq_ms
        return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)


_sequence_allocator = _SequenceAllocator()


def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    if not isinstance(content, str):
        next_metadata = {**metadata, "content_is_json": True}
        if isinstance(content, dict):
            next_metadata["content_is_dict"] = True
        return json.dumps(content, default=str, ensure_ascii=False), next_metadata
    return content, metadata


def _deserialize_content(content: str, metadata: dict[str, Any]) -> Any:
    if not (metadata.get("content_is_json") or metadata.get("content_is_dict")):
        return content
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        return content


def _to_run_event(model: RunEventModel) -> RunEvent:
    raw_metadata = dict(model.meta or {})
    metadata = {key: value for key, value in raw_metadata.items() if key != "content_is_dict"}
    return RunEvent(
        thread_id=model.thread_id,
        run_id=model.run_id,
        user_id=model.user_id,
        event_type=model.event_type,
        category=model.category,
        content=_deserialize_content(model.content, raw_metadata),
        metadata=metadata,
        seq=model.seq,
        created_at=model.created_at,
    )


class DbRunEventRepository(RunEventRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
        if not events:
            return []

        seq_base = _sequence_allocator.allocate_base(len(events))

        rows: list[RunEventModel] = []

        for index, event in enumerate(events, start=1):
            content, metadata = _serialize_content(event.content, dict(event.metadata))
            row = RunEventModel(
                thread_id=event.thread_id,
                run_id=event.run_id,
                user_id=event.user_id,
                seq=seq_base + index,
                event_type=event.event_type,
                category=event.category,
                content=content,
                meta=metadata,
            )
            if event.created_at is not None:
                row.created_at = event.created_at
            self._session.add(row)
            rows.append(row)

        await self._session.flush()
        return [_to_run_event(row) for row in rows]

    async def list_messages(
        self,
        thread_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        stmt = select(RunEventModel).where(
            RunEventModel.thread_id == thread_id,
            RunEventModel.category == "message",
        )
        if user_id is not None:
            stmt = stmt.where(RunEventModel.user_id == user_id)
        if before_seq is not None:
            stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
            result = await self._session.execute(stmt)
            return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
        if after_seq is not None:
            stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
            result = await self._session.execute(stmt)
            return [_to_run_event(row) for row in result.scalars().all()]

        stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
        result = await self._session.execute(stmt)
        return list(reversed([_to_run_event(row) for row in result.scalars().all()]))

    async def list_events(
        self,
        thread_id: str,
        run_id: str,
        *,
        event_types: list[str] | None = None,
        limit: int = 500,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        stmt = select(RunEventModel).where(
            RunEventModel.thread_id == thread_id,
            RunEventModel.run_id == run_id,
        )
        if user_id is not None:
            stmt = stmt.where(RunEventModel.user_id == user_id)
        if event_types is not None:
            stmt = stmt.where(RunEventModel.event_type.in_(event_types))
        stmt = stmt.order_by(RunEventModel.seq.asc()).limit(limit)
        result = await self._session.execute(stmt)
        return [_to_run_event(row) for row in result.scalars().all()]

    async def list_messages_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        stmt = select(RunEventModel).where(
            RunEventModel.thread_id == thread_id,
            RunEventModel.run_id == run_id,
            RunEventModel.category == "message",
        )
        if user_id is not None:
            stmt = stmt.where(RunEventModel.user_id == user_id)
        if before_seq is not None:
            stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
            result = await self._session.execute(stmt)
            return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
        if after_seq is not None:
            stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
            result = await self._session.execute(stmt)
            return [_to_run_event(row) for row in result.scalars().all()]

        stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
        result = await self._session.execute(stmt)
        return list(reversed([_to_run_event(row) for row in result.scalars().all()]))

    async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
        stmt = select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
        if user_id is not None:
            stmt = stmt.where(RunEventModel.user_id == user_id)
        count = await self._session.scalar(stmt)
        return int(count or 0)

    async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
        conditions = [RunEventModel.thread_id == thread_id]
        if user_id is not None:
            conditions.append(RunEventModel.user_id == user_id)
        count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
        await self._session.execute(delete(RunEventModel).where(*conditions))
        return int(count or 0)

    async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
        conditions = [RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id]
        if user_id is not None:
            conditions.append(RunEventModel.user_id == user_id)
        count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
        await self._session.execute(delete(RunEventModel).where(*conditions))
        return int(count or 0)
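The `seq` values written by `append_batch` pack three fields into one integer: a millisecond timestamp in the high bits, a 9-bit per-process salt, and a 12-bit per-batch counter. The sketch below shows that layout in isolation. It reuses the bit widths from the diff, but `compose_seq` and `decompose_seq` are hypothetical helpers for illustration, and the salt is passed in explicitly rather than drawn from `secrets.randbits`.

```python
COUNTER_BITS = 12  # low bits: position of the event inside one batch
PROCESS_BITS = 9   # middle bits: per-process salt
TIMESTAMP_SHIFT = COUNTER_BITS + PROCESS_BITS  # high bits: milliseconds


def compose_seq(millis: int, process_salt: int, counter: int) -> int:
    """Pack (millis, salt, counter) into a single sortable integer."""
    if not 0 <= counter < (1 << COUNTER_BITS):
        raise ValueError("counter does not fit in COUNTER_BITS")
    if not 0 <= process_salt < (1 << PROCESS_BITS):
        raise ValueError("process_salt does not fit in PROCESS_BITS")
    return (millis << TIMESTAMP_SHIFT) | (process_salt << COUNTER_BITS) | counter


def decompose_seq(seq: int) -> tuple[int, int, int]:
    """Recover (millis, salt, counter) from a packed sequence number."""
    counter = seq & ((1 << COUNTER_BITS) - 1)
    process_salt = (seq >> COUNTER_BITS) & ((1 << PROCESS_BITS) - 1)
    millis = seq >> TIMESTAMP_SHIFT
    return millis, process_salt, counter


# Illustrative values only.
packed = compose_seq(1_700_000_000_000, 300, 5)
unpacked = decompose_seq(packed)
```

Because the counter bits of the allocated base are zero and each batch index stays below `1 << 12`, adding the index to the base (as `append_batch` does) gives the same result as OR-ing it in. A present-day millisecond timestamp shifted left by 21 bits still fits in a signed 64-bit `BigInteger` column.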
@@ -0,0 +1,113 @@
from __future__ import annotations

import logging
from typing import Any

from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from store.persistence.json_compat import json_match
from store.repositories.contracts.thread_meta import (
    InvalidMetadataFilterError,
    ThreadMeta,
    ThreadMetaCreate,
    ThreadMetaRepositoryProtocol,
)
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel

logger = logging.getLogger(__name__)


def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
    return ThreadMeta(
        thread_id=m.thread_id,
        assistant_id=m.assistant_id,
        user_id=m.user_id,
        display_name=m.display_name,
        status=m.status,
        metadata=dict(m.meta or {}),
        created_time=m.created_time,
        updated_time=m.updated_time,
    )


class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
        model = ThreadMetaModel(
            thread_id=data.thread_id,
            assistant_id=data.assistant_id,
            user_id=data.user_id,
            display_name=data.display_name,
            status=data.status,
            meta=dict(data.metadata),
        )
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_thread_meta(model)

    async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
        result = await self._session.execute(select(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
        model = result.scalar_one_or_none()
        return _to_thread_meta(model) if model else None

    async def update_thread_meta(
        self,
        thread_id: str,
        *,
        display_name: str | None = None,
        status: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        values: dict = {}
        if display_name is not None:
            values["display_name"] = display_name
        if status is not None:
            values["status"] = status
        if metadata is not None:
            values["meta"] = dict(metadata)
        if not values:
            return
        await self._session.execute(update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))

    async def delete_thread(self, thread_id: str) -> None:
        await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None = None,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        stmt = select(ThreadMetaModel)

        if status is not None:
            stmt = stmt.where(ThreadMetaModel.status == status)
        if user_id is not None:
            stmt = stmt.where(ThreadMetaModel.user_id == user_id)
        if assistant_id is not None:
            stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
        if metadata:
            applied = 0
            for key, value in metadata.items():
                try:
                    stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
                    applied += 1
                except (ValueError, TypeError) as exc:
                    logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
            if applied == 0:
                rejected_keys = ", ".join(sorted(str(key) for key in metadata))
                raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")

        stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
        stmt = stmt.limit(limit).offset(offset)

        result = await self._session.execute(stmt)
        return [_to_thread_meta(m) for m in result.scalars().all()]
@@ -0,0 +1,98 @@
from __future__ import annotations

from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.user import User, UserCreate, UserNotFoundError, UserRepositoryProtocol
from store.repositories.models.user import User as UserModel


def _to_user(model: UserModel) -> User:
    return User(
        id=model.id,
        email=model.email,
        password_hash=model.password_hash,
        system_role=model.system_role,  # type: ignore[arg-type]
        created_at=model.created_at,
        oauth_provider=model.oauth_provider,
        oauth_id=model.oauth_id,
        needs_setup=model.needs_setup,
        token_version=model.token_version,
    )


class DbUserRepository(UserRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_user(self, data: UserCreate) -> User:
        model = UserModel(
            id=data.id,
            email=data.email,
            system_role=data.system_role,
            password_hash=data.password_hash,
            oauth_provider=data.oauth_provider,
            oauth_id=data.oauth_id,
            needs_setup=data.needs_setup,
            token_version=data.token_version,
        )
        if data.created_at is not None:
            model.created_at = data.created_at
        self._session.add(model)
        try:
            await self._session.flush()
        except IntegrityError as exc:
            await self._session.rollback()
            raise ValueError(f"Email already registered: {data.email}") from exc
        await self._session.refresh(model)
        return _to_user(model)

    async def get_user_by_id(self, user_id: str) -> User | None:
        model = await self._session.get(UserModel, user_id)
        return _to_user(model) if model is not None else None

    async def get_user_by_email(self, email: str) -> User | None:
        result = await self._session.execute(select(UserModel).where(UserModel.email == email))
        model = result.scalar_one_or_none()
        return _to_user(model) if model is not None else None

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        result = await self._session.execute(
            select(UserModel).where(
                UserModel.oauth_provider == provider,
                UserModel.oauth_id == oauth_id,
            )
        )
        model = result.scalar_one_or_none()
        return _to_user(model) if model is not None else None

    async def get_first_admin(self) -> User | None:
        result = await self._session.execute(select(UserModel).where(UserModel.system_role == "admin").limit(1))
        model = result.scalar_one_or_none()
        return _to_user(model) if model is not None else None

    async def update_user(self, data: User) -> User:
        model = await self._session.get(UserModel, data.id)
        if model is None:
            raise UserNotFoundError(f"User {data.id} no longer exists")

        model.email = data.email
        model.password_hash = data.password_hash
        model.system_role = data.system_role
        model.oauth_provider = data.oauth_provider
        model.oauth_id = data.oauth_id
        model.needs_setup = data.needs_setup
        model.token_version = data.token_version

        await self._session.flush()
        await self._session.refresh(model)
        return _to_user(model)

    async def count_users(self) -> int:
        count = await self._session.scalar(select(func.count()).select_from(UserModel))
        return int(count or 0)

    async def count_admin_users(self) -> int:
        count = await self._session.scalar(select(func.count()).select_from(UserModel).where(UserModel.system_role == "admin"))
        return int(count or 0)
@@ -0,0 +1,36 @@
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories import (
    FeedbackRepositoryProtocol,
    RunEventRepositoryProtocol,
    RunRepositoryProtocol,
    ThreadMetaRepositoryProtocol,
    UserRepositoryProtocol,
)
from store.repositories.db import (
    DbFeedbackRepository,
    DbRunEventRepository,
    DbRunRepository,
    DbThreadMetaRepository,
    DbUserRepository,
)


def build_thread_meta_repository(session: AsyncSession) -> ThreadMetaRepositoryProtocol:
    return DbThreadMetaRepository(session)


def build_run_repository(session: AsyncSession) -> RunRepositoryProtocol:
    return DbRunRepository(session)


def build_feedback_repository(session: AsyncSession) -> FeedbackRepositoryProtocol:
    return DbFeedbackRepository(session)


def build_run_event_repository(session: AsyncSession) -> RunEventRepositoryProtocol:
    return DbRunEventRepository(session)


def build_user_repository(session: AsyncSession) -> UserRepositoryProtocol:
    return DbUserRepository(session)
@@ -0,0 +1,7 @@
from store.repositories.models.feedback import Feedback
from store.repositories.models.run import Run
from store.repositories.models.run_event import RunEvent
from store.repositories.models.thread_meta import ThreadMeta
from store.repositories.models.user import User

__all__ = ["Feedback", "Run", "RunEvent", "ThreadMeta", "User"]
@@ -0,0 +1,36 @@
from __future__ import annotations

from datetime import datetime

from sqlalchemy import Integer, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time


class Feedback(DataClassBase):
    """Feedback table (create-only, no updated_time)."""

    __tablename__ = "feedback"
    __table_args__ = (
        UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
        {"comment": "Feedback table."},
    )

    feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    run_id: Mapped[str] = mapped_column(String(64), index=True)
    thread_id: Mapped[str] = mapped_column(String(64), index=True)
    rating: Mapped[int] = mapped_column(Integer)

    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    message_id: Mapped[str | None] = mapped_column(String(64), default=None)
    comment: Mapped[str | None] = mapped_column(UniversalText, default=None)

    created_time: Mapped[datetime] = mapped_column(
        "created_at",
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Created at",
    )
@@ -0,0 +1,63 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

from sqlalchemy import JSON, Index, Integer, String
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time


class Run(DataClassBase):
    """Run metadata table."""

    __tablename__ = "runs"
    __table_args__ = (
        Index("ix_runs_thread_status", "thread_id", "status"),
        {"comment": "Run metadata table."},
    )

    run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    thread_id: Mapped[str] = mapped_column(String(64), index=True)

    assistant_id: Mapped[str | None] = mapped_column(String(128), default=None)
    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    status: Mapped[str] = mapped_column(String(20), default="pending", index=True)
    model_name: Mapped[str | None] = mapped_column(String(128), default=None)
    multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
    error: Mapped[str | None] = mapped_column(UniversalText, default=None)
    follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64), default=None)

    meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
    kwargs: Mapped[dict[str, Any]] = mapped_column("kwargs_json", JSON, default_factory=dict)

    total_input_tokens: Mapped[int] = mapped_column(Integer, default=0)
    total_output_tokens: Mapped[int] = mapped_column(Integer, default=0)
    total_tokens: Mapped[int] = mapped_column(Integer, default=0)
    llm_call_count: Mapped[int] = mapped_column(Integer, default=0)
    lead_agent_tokens: Mapped[int] = mapped_column(Integer, default=0)
    subagent_tokens: Mapped[int] = mapped_column(Integer, default=0)
    middleware_tokens: Mapped[int] = mapped_column(Integer, default=0)

    message_count: Mapped[int] = mapped_column(Integer, default=0)
    first_human_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
    last_ai_message: Mapped[str | None] = mapped_column(UniversalText, default=None)

    created_time: Mapped[datetime] = mapped_column(
        "created_at",
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Created at",
    )
    updated_time: Mapped[datetime | None] = mapped_column(
        "updated_at",
        TimeZone,
        init=False,
        default=None,
        onupdate=current_time,
        sort_order=999,
        comment="Updated at",
    )
@@ -0,0 +1,46 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import (
    DataClassBase,
    TimeZone,
    UniversalText,
    current_time,
    id_key,
)


class RunEvent(DataClassBase):
    """Run event table."""

    __tablename__ = "run_events"
    __table_args__ = (
        UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
        Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
        Index("ix_events_run", "thread_id", "run_id", "seq"),
        {"comment": "Run event table."},
    )

    id: Mapped[id_key] = mapped_column(init=False)

    thread_id: Mapped[str] = mapped_column(String(64), index=True)
    run_id: Mapped[str] = mapped_column(String(64), index=True)
    event_type: Mapped[str] = mapped_column(String(32), index=True)
    category: Mapped[str] = mapped_column(String(16), index=True)

    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
    content: Mapped[str] = mapped_column(UniversalText, default="")
    meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
    created_at: Mapped[datetime] = mapped_column(
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Event timestamp",
    )
@@ -0,0 +1,43 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import DataClassBase, TimeZone, current_time


class ThreadMeta(DataClassBase):
    """Thread metadata table."""

    __tablename__ = "threads_meta"
    __table_args__ = {"comment": "Thread metadata table."}

    thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)

    assistant_id: Mapped[str | None] = mapped_column(String(128), default=None, index=True)
    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    display_name: Mapped[str | None] = mapped_column(String(256), default=None)
    status: Mapped[str] = mapped_column(String(20), default="idle", index=True)

    meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)

    created_time: Mapped[datetime] = mapped_column(
        "created_at",
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Created at",
    )
    updated_time: Mapped[datetime | None] = mapped_column(
        "updated_at",
        TimeZone,
        init=False,
        default=None,
        onupdate=current_time,
        sort_order=999,
        comment="Updated at",
    )
@@ -0,0 +1,42 @@
from __future__ import annotations

from datetime import datetime

from sqlalchemy import Boolean, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import DataClassBase, TimeZone, current_time


class User(DataClassBase):
    """User account table."""

    __tablename__ = "users"
    __table_args__ = (
        Index(
            "idx_users_oauth_identity",
            "oauth_provider",
            "oauth_id",
            unique=True,
            sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
        ),
        {"comment": "User account table."},
    )

    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
    system_role: Mapped[str] = mapped_column(String(16), default="user")

    password_hash: Mapped[str | None] = mapped_column(String(128), default=None)
    oauth_provider: Mapped[str | None] = mapped_column(String(32), default=None)
    oauth_id: Mapped[str | None] = mapped_column(String(128), default=None)
    needs_setup: Mapped[bool] = mapped_column(Boolean, default=False)
    token_version: Mapped[int] = mapped_column(default=0)

    created_at: Mapped[datetime] = mapped_column(
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Created at",
    )
@@ -0,0 +1,3 @@
from .timezone import get_timezone

__all__ = ["get_timezone"]
@@ -0,0 +1,51 @@
import zoneinfo
from datetime import UTC, datetime

from store.config.app_config import get_app_config

# IANA identifiers that map to UTC; see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones
_UTC_IDENTIFIERS = frozenset({"Etc/UCT", "Etc/Universal", "Etc/UTC", "Etc/Zulu", "UCT", "Universal", "UTC", "Zulu"})


class TimeZone:
    def __init__(self) -> None:
        app_config = get_app_config()
        if app_config.timezone in _UTC_IDENTIFIERS:
            self.tz_info = UTC
        else:
            self.tz_info = zoneinfo.ZoneInfo(app_config.timezone)

    def now(self) -> datetime:
        """Return the current time in the configured timezone."""
        return datetime.now(self.tz_info)

    def from_datetime(self, t: datetime) -> datetime:
        """Convert a datetime to the configured timezone."""
        return t.astimezone(self.tz_info)

    def from_str(self, t_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> datetime:
        """Parse a time string and attach the configured timezone."""
        return datetime.strptime(t_str, format_str).replace(tzinfo=self.tz_info)

    @staticmethod
    def to_str(t: datetime, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
        """Format a datetime to string."""
        return t.strftime(format_str)

    @staticmethod
    def to_utc(t: datetime | int) -> datetime:
        """Convert a datetime or Unix timestamp to UTC."""
        if isinstance(t, datetime):
            return t.astimezone(UTC)
        return datetime.fromtimestamp(t, tz=UTC)


_timezone = None


def get_timezone() -> TimeZone:
    """Return the global TimeZone singleton (lazy-initialized)."""
    global _timezone
    if _timezone is None:
        _timezone = TimeZone()
    return _timezone
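Note on the timezone helper above: `from_str` attaches the configured zone with `replace(tzinfo=...)`, so the wall-clock reading stays the same, whereas `from_datetime` converts with `astimezone`, so the instant stays the same and the wall clock shifts. The sketch below illustrates that difference only. It is not the package code: `CONFIGURED_TZ`, `attach_zone`, and `convert_zone` are hypothetical names, and a fixed UTC+8 offset stands in for the configured zone so the example does not depend on the system tz database.

```python
from datetime import datetime, timedelta, timezone

# Hypothetical fixed-offset zone standing in for the configured timezone.
CONFIGURED_TZ = timezone(timedelta(hours=8))


def attach_zone(text: str, fmt: str = "%Y-%m-%d %H:%M:%S") -> datetime:
    """Parse a naive string and label it with the zone (wall clock unchanged)."""
    return datetime.strptime(text, fmt).replace(tzinfo=CONFIGURED_TZ)


def convert_zone(moment: datetime) -> datetime:
    """Convert an aware datetime to the zone (same instant, new wall clock)."""
    return moment.astimezone(CONFIGURED_TZ)


# "12:00" is taken to already be local time in the configured zone.
attached = attach_zone("2024-01-01 12:00:00")

# 12:00 UTC is the same instant as 20:00 at UTC+8.
converted = convert_zone(datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc))
```

In this sketch `attached` still reads 12:00 while `converted` reads 20:00. Which behavior is wanted depends on whether the input string is already expressed in the configured zone.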
@@ -6,6 +6,7 @@ readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
     "deerflow-harness",
+    "deerflow-storage",
     "fastapi>=0.115.0",
     "httpx>=0.28.0",
     "python-multipart>=0.0.27",
@@ -24,7 +25,8 @@ dependencies = [
 ]

 [project.optional-dependencies]
-postgres = ["deerflow-harness[postgres]"]
+postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
+mysql = ["deerflow-storage[mysql]"]

 [dependency-groups]
 dev = [
@@ -43,7 +45,8 @@ markers = [
 index-url = "https://pypi.org/simple"

 [tool.uv.workspace]
-members = ["packages/harness"]
+members = ["packages/harness", "packages/storage"]

 [tool.uv.sources]
 deerflow-harness = { workspace = true }
+deerflow-storage = { workspace = true }
@@ -94,12 +94,15 @@ class TestHarnessPackaging:
             "psycopg-pool>=3.3.0",
         ]

-    def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
+    def test_workspace_pyproject_forwards_postgres_extra_to_storage_packages(self):
         pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
         data = tomllib.loads(pyproject_path.read_text())

         optional_dependencies = data["project"]["optional-dependencies"]
-        assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
+        assert optional_dependencies["postgres"] == [
+            "deerflow-harness[postgres]",
+            "deerflow-storage[postgres]",
+        ]

     def test_postgres_missing_dependency_messages_recommend_package_extra(self):
         assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
@@ -1,6 +1,6 @@
 import threading
 import time
-from unittest.mock import MagicMock, patch
+from unittest.mock import MagicMock, call, patch

 from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
 from deerflow.config.memory_config import MemoryConfig
@@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None:
     assert elapsed < 0.1
     assert finished.is_set() is False
     assert finished.wait(1.0) is True
+
+
+def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None:
+    queue = MemoryUpdateQueue()
+
+    with (
+        patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
+        patch.object(queue, "_reset_timer"),
+    ):
+        queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
+        queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
+
+    assert queue.pending_count == 2
+    assert [context.agent_name for context in queue._queue] == ["agent-a", "agent-b"]
+
+
+def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None:
+    queue = MemoryUpdateQueue()
+
+    with (
+        patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
+        patch.object(queue, "_reset_timer"),
+    ):
+        queue.add(
+            thread_id="thread-1",
+            messages=["first"],
+            agent_name="agent-a",
+            correction_detected=True,
+        )
+        queue.add(
+            thread_id="thread-1",
+            messages=["second"],
+            agent_name="agent-a",
+            correction_detected=False,
+        )
+
+    assert queue.pending_count == 1
+    assert queue._queue[0].agent_name == "agent-a"
+    assert queue._queue[0].messages == ["second"]
+    assert queue._queue[0].correction_detected is True
+
+
+def test_process_queue_updates_different_agents_in_same_thread_separately() -> None:
+    queue = MemoryUpdateQueue()
+
+    with (
+        patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
+        patch.object(queue, "_reset_timer"),
+    ):
+        queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
+        queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
+
+    mock_updater = MagicMock()
+    mock_updater.update_memory.return_value = True
+
+    with (
+        patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater),
+        patch("deerflow.agents.memory.queue.time.sleep"),
+    ):
+        queue.flush()
+
+    assert mock_updater.update_memory.call_count == 2
+    mock_updater.update_memory.assert_has_calls(
+        [
+            call(
+                messages=["agent-a"],
+                thread_id="thread-1",
+                agent_name="agent-a",
+                correction_detected=False,
+                reinforcement_detected=False,
+                user_id=None,
+            ),
+            call(
+                messages=["agent-b"],
+                thread_id="thread-1",
+                agent_name="agent-b",
+                correction_detected=False,
+                reinforcement_detected=False,
+                user_id=None,
+            ),
+        ]
+    )
@@ -38,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater():
     mock_updater.update_memory.assert_called_once()
     call_kwargs = mock_updater.update_memory.call_args.kwargs
     assert call_kwargs["user_id"] == "alice"
+
+
+def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent():
+    q = MemoryUpdateQueue()
+
+    with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
+        q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
+        q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
+
+    assert q.pending_count == 2
+    assert [context.user_id for context in q._queue] == ["alice", "bob"]
+    assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]]
+
+
+def test_queue_still_coalesces_updates_for_same_user_thread_and_agent():
+    q = MemoryUpdateQueue()
+
+    with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
+        q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice")
+        q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice")
+
+    assert q.pending_count == 1
+    assert q._queue[0].messages == ["second"]
+    assert q._queue[0].user_id == "alice"
+    assert q._queue[0].agent_name == "researcher"
+
+
+def test_add_nowait_keeps_different_users_separate():
+    q = MemoryUpdateQueue()
+
+    with (
+        patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
+        patch.object(q, "_schedule_timer"),
+    ):
+        q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
+        q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
+
+    assert q.pending_count == 2
+    assert [context.user_id for context in q._queue] == ["alice", "bob"]
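Note on the queue tests above: taken together they pin down a coalescing rule in which pending updates are merged only when thread, agent, and user all match, the newest messages replace the older ones, and a previously detected correction is kept. A minimal sketch of that rule follows. `PendingUpdate` and `CoalescingQueue` are hypothetical stand-ins and not the real `ConversationContext` or `MemoryUpdateQueue`; in particular, re-appending the merged entry at the tail is an assumption that the tests do not confirm.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class PendingUpdate:
    """Hypothetical stand-in for a queued memory update."""

    thread_id: str
    messages: list[str]
    agent_name: str | None = None
    user_id: str | None = None
    correction_detected: bool = False


class CoalescingQueue:
    """Illustrative queue that merges updates sharing (thread, agent, user)."""

    def __init__(self) -> None:
        self._queue: list[PendingUpdate] = []

    @property
    def pending_count(self) -> int:
        return len(self._queue)

    def add(self, update: PendingUpdate) -> None:
        key = (update.thread_id, update.agent_name, update.user_id)
        for index, existing in enumerate(self._queue):
            if (existing.thread_id, existing.agent_name, existing.user_id) == key:
                # Newest messages win, but an earlier correction flag is kept.
                update.correction_detected = update.correction_detected or existing.correction_detected
                del self._queue[index]
                break
        # Assumption: the merged entry moves to the tail of the queue.
        self._queue.append(update)


queue = CoalescingQueue()
queue.add(PendingUpdate("main", ["first"], "researcher", "alice", correction_detected=True))
queue.add(PendingUpdate("main", ["second"], "researcher", "alice"))
queue.add(PendingUpdate("main", ["bob update"], "researcher", "bob"))
```

Keying on the full triple is what keeps two users, or two agents, in the same thread from overwriting each other's pending update.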
@@ -268,6 +268,39 @@ class TestEdgeCases:
 class TestDbRunEventStore:
     """Tests for DbRunEventStore with temp SQLite."""

+    @pytest.mark.anyio
+    async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self):
+        from sqlalchemy.dialects import postgresql
+
+        from deerflow.runtime.events.store.db import DbRunEventStore
+
+        class FakeSession:
+            def __init__(self):
+                self.dialect = postgresql.dialect()
+                self.execute_calls = []
+                self.scalar_stmt = None
+
+            def get_bind(self):
+                return self
+
+            async def execute(self, stmt, params=None):
+                self.execute_calls.append((stmt, params))
+
+            async def scalar(self, stmt):
+                self.scalar_stmt = stmt
+                return 41
+
+        session = FakeSession()
+
+        max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1")
+
+        assert max_seq == 41
+        assert session.execute_calls
+        assert session.execute_calls[0][1] == {"thread_id": "thread-1"}
+        assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0])
+        compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect()))
+        assert "FOR UPDATE" not in compiled
+
     @pytest.mark.anyio
     async def test_basic_crud(self, tmp_path):
         from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
@@ -0,0 +1,81 @@
from __future__ import annotations

import os
from pathlib import Path

import pytest
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.types import JSON

os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))

from store.persistence.json_compat import json_match


def _table():
    metadata = MetaData()
    return Table("t", metadata, Column("data", JSON), Column("id", String))


def test_storage_json_match_compiles_sqlite() -> None:
    from sqlalchemy import create_engine

    table = _table()
    dialect = create_engine("sqlite://").dialect

    assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'null'")
    assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'true'")

    int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert "= 'integer'" in int_sql
    assert "CAST" in int_sql

    float_sql = str(json_match(table.c.data, "k", 3.14).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert "IN ('integer', 'real')" in float_sql
    assert "REAL" in float_sql


def test_storage_json_match_compiles_postgres() -> None:
    table = _table()
    dialect = postgresql.dialect()

    assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_typeof(t.data -> 'k') = 'null'")
    assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')")

    int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert "CASE WHEN" in int_sql
    assert "BIGINT" in int_sql
    assert "'^-?[0-9]+$'" in int_sql


def test_storage_json_match_compiles_mysql() -> None:
    table = _table()
    dialect = mysql.dialect()

    null_sql = str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert null_sql == "JSON_TYPE(JSON_EXTRACT(t.data, '$.\"k\"')) = 'NULL'"

    bool_sql = str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert "JSON_TYPE(JSON_EXTRACT" in bool_sql
    assert "= 'BOOLEAN'" in bool_sql
    assert "= 'true'" in bool_sql

    int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert "= 'INTEGER'" in int_sql
    assert "SIGNED" in int_sql


def test_storage_json_match_rejects_unsafe_keys_and_values() -> None:
    table = _table()

    for bad_key in ["a.b", "bad;key", "with space", "", 42, None]:
        with pytest.raises(ValueError, match="JsonMatch key must match"):
            json_match(table.c.data, bad_key, "x")  # type: ignore[arg-type]

    for bad_value in [[], {}, object()]:
        with pytest.raises(TypeError, match="JsonMatch value must be"):
            json_match(table.c.data, "k", bad_value)

    with pytest.raises(TypeError, match="out of signed 64-bit range"):
        json_match(table.c.data, "k", 2**63)
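Note on the `json_match` tests above: they imply two rules, namely that keys and values are validated before any SQL is built, and that matching is type-aware (a JSON `true` is not the integer `1`, and an explicit `null` is not a missing key). The pure-Python sketch below mirrors those rules so they can be checked without a database. It is an illustration only: `validate_filter` and `strict_match` are hypothetical helpers, and `_KEY_PATTERN` is an assumed pattern, because the real one lives in `json_compat` and is not shown in this diff.

```python
import re

# Assumed key pattern; the real pattern is not visible in this diff.
_KEY_PATTERN = re.compile(r"[A-Za-z0-9_-]+")
_INT64_MIN, _INT64_MAX = -(2**63), 2**63 - 1


def validate_filter(key: object, value: object) -> None:
    """Reject keys and values that are unsafe or unsupported for matching."""
    if not isinstance(key, str) or not _KEY_PATTERN.fullmatch(key):
        raise ValueError(f"key must match {_KEY_PATTERN.pattern!r}")
    if value is not None and not isinstance(value, (bool, int, float, str)):
        raise TypeError("value must be None, bool, int, float, or str")
    is_plain_int = isinstance(value, int) and not isinstance(value, bool)
    if is_plain_int and not _INT64_MIN <= value <= _INT64_MAX:
        raise TypeError("integer value out of signed 64-bit range")


def strict_match(stored: dict, key: str, expected: object) -> bool:
    """Type-aware equality between a stored JSON value and a filter value."""
    validate_filter(key, expected)
    if key not in stored:
        return False  # a missing key never matches, not even None
    actual = stored[key]
    if expected is None or actual is None:
        return expected is None and actual is None
    # bool is a subclass of int in Python, so it must be checked first.
    if isinstance(expected, bool):
        return isinstance(actual, bool) and actual == expected
    if isinstance(actual, bool):
        return False
    if isinstance(expected, int):
        return isinstance(actual, int) and actual == expected
    if isinstance(expected, float):
        return isinstance(actual, (int, float)) and actual == expected
    return isinstance(actual, str) and actual == expected
```

The float branch accepts stored integers because the SQLite test above expects a float filter to compile against both `'integer'` and `'real'` JSON types.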
@@ -0,0 +1,122 @@
from __future__ import annotations

import os
import subprocess
import sys
from pathlib import Path
from types import SimpleNamespace

import pytest

os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))

from store.config.storage_config import StorageConfig
from store.persistence.factory import _create_database_url, storage_config_from_database_config


def test_database_sqlite_config_maps_to_storage_config(tmp_path):
    database = SimpleNamespace(
        backend="sqlite",
        sqlite_dir=str(tmp_path),
        echo_sql=True,
        pool_size=9,
    )

    storage = storage_config_from_database_config(database)

    assert storage == StorageConfig(
        driver="sqlite",
        sqlite_dir=str(tmp_path),
        echo_sql=True,
        pool_size=9,
    )
    assert storage.sqlite_storage_path == str(tmp_path / "deerflow.db")


def test_database_memory_config_is_not_a_storage_backend():
    database = SimpleNamespace(backend="memory")

    with pytest.raises(ValueError, match="Unsupported database backend"):
        storage_config_from_database_config(database)


def test_database_postgres_config_preserves_url_and_pool_options():
    database = SimpleNamespace(
        backend="postgres",
        postgres_url="postgresql://user:pass@db.example:5544/deerflow",
        echo_sql=True,
        pool_size=11,
    )

    storage = storage_config_from_database_config(database)
    url = _create_database_url(storage)

    assert storage.driver == "postgres"
    assert storage.database_url == "postgresql://user:pass@db.example:5544/deerflow"
    assert storage.username == "user"
    assert storage.password == "pass"
    assert storage.host == "db.example"
    assert storage.port == 5544
    assert storage.db_name == "deerflow"
    assert storage.echo_sql is True
    assert storage.pool_size == 11
    assert url.drivername == "postgresql+asyncpg"
    assert url.database == "deerflow"


def test_mysql_database_url_is_normalized_to_async_driver():
    storage = StorageConfig(
        driver="mysql",
        database_url="mysql://user:pass@db.example:3306/deerflow",
    )

    url = _create_database_url(storage)

    assert url.drivername == "mysql+aiomysql"
    assert url.database == "deerflow"


def test_mysql_async_database_url_is_preserved():
    storage = StorageConfig(
        driver="mysql",
        database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow",
    )

    url = _create_database_url(storage)

    assert url.drivername == "mysql+asyncmy"
    assert url.database == "deerflow"


def test_database_postgres_requires_url():
    database = SimpleNamespace(backend="postgres", postgres_url="")

    with pytest.raises(ValueError, match="database.postgres_url is required"):
        storage_config_from_database_config(database)


def test_unsupported_database_backend_rejected():
    database = SimpleNamespace(backend="oracle")

    with pytest.raises(ValueError, match="Unsupported database backend"):
        storage_config_from_database_config(database)


def test_storage_models_import_without_config_file(tmp_path):
    env = os.environ.copy()
    env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml")

    result = subprocess.run(
        [
            sys.executable,
            "-c",
            "from store.persistence.base_model import UniversalText, id_key; from store.repositories.models import RunEvent; print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
        ],
        check=False,
        capture_output=True,
        env=env,
        text=True,
    )

    assert result.returncode == 0, result.stderr
    assert "UniversalText run_events" in result.stdout
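Note on the URL tests above: they expect a bare `postgresql://` or `mysql://` URL to gain a default async driver (`asyncpg`, `aiomysql`), while a URL that already names a driver such as `mysql+asyncmy://` is left alone. The string-level sketch below shows only that normalization rule. `normalize_scheme` is a hypothetical helper, not the package's `_create_database_url`, which, judging by the `.drivername` and `.database` attributes used in the tests, returns a URL object rather than a string.

```python
# Default async drivers, taken from the expectations in the tests above.
_ASYNC_DRIVERS = {
    "postgresql": "postgresql+asyncpg",
    "mysql": "mysql+aiomysql",
}


def normalize_scheme(database_url: str) -> str:
    """Hypothetical helper: add a default async driver to a bare scheme."""
    scheme, separator, rest = database_url.partition("://")
    if not separator:
        raise ValueError("database URL must contain '://'")
    # A scheme that already names a driver (for example "mysql+asyncmy")
    # is not a key in the mapping, so it passes through unchanged.
    return f"{_ASYNC_DRIVERS.get(scheme, scheme)}://{rest}"


bare_mysql = normalize_scheme("mysql://user:pass@db.example:3306/deerflow")
explicit_mysql = normalize_scheme("mysql+asyncmy://user:pass@db.example:3306/deerflow")
bare_postgres = normalize_scheme("postgresql://user:pass@db.example:5544/deerflow")
```

Only the scheme is rewritten here; credentials, host, port, and database name are carried over untouched.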
@@ -0,0 +1,58 @@
from __future__ import annotations

import asyncio
import os
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4

os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))

from sqlalchemy import inspect
from store.persistence import create_persistence_from_database_config
from store.repositories import UserCreate, build_user_repository


def test_sqlite_persistence_from_database_config_creates_storage_tables(tmp_path):
    async def run() -> None:
        persistence = await create_persistence_from_database_config(
            SimpleNamespace(
                backend="sqlite",
                sqlite_dir=str(tmp_path),
                echo_sql=False,
                pool_size=5,
            )
        )
        assert persistence is not None
        try:
            await persistence.setup()

            async with persistence.engine.connect() as conn:
                tables = await conn.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names()))

            assert {
                "users",
                "runs",
                "run_events",
                "threads_meta",
                "feedback",
            }.issubset(tables)

            async with persistence.session_factory() as session:
                repo = build_user_repository(session)
                user = await repo.create_user(
                    UserCreate(
                        id=str(uuid4()),
                        email="storage-user@example.com",
                        password_hash="hash",
                    )
                )
                await session.commit()

            async with persistence.session_factory() as session:
                repo = build_user_repository(session)
                assert await repo.get_user_by_id(user.id) == user
        finally:
            await persistence.aclose()

    asyncio.run(run())
@@ -0,0 +1,395 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.repositories import (
|
||||
FeedbackCreate,
|
||||
InvalidMetadataFilterError,
|
||||
RunCreate,
|
||||
RunEventCreate,
|
||||
ThreadMetaCreate,
|
||||
build_feedback_repository,
|
||||
build_run_event_repository,
|
||||
build_run_repository,
|
||||
build_thread_meta_repository,
|
||||
)
|
||||
|
||||
|
||||
async def _make_persistence(tmp_path):
|
||||
persistence = await create_persistence_from_database_config(
|
||||
SimpleNamespace(
|
||||
backend="sqlite",
|
||||
sqlite_dir=str(tmp_path),
|
||||
echo_sql=False,
|
||||
pool_size=5,
|
||||
)
|
||||
)
|
||||
await persistence.setup()
|
||||
return persistence
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_run_repository_filters_and_aggregates(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
old = datetime.now(UTC) - timedelta(hours=1)
|
||||
newer = datetime.now(UTC)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_repository(session)
|
||||
await repo.create_run(
|
||||
RunCreate(
|
||||
run_id="run-old",
|
||||
thread_id="thread-1",
|
||||
user_id="alice",
|
||||
status="pending",
|
||||
model_name="model-a",
|
||||
metadata={"kind": "draft"},
|
||||
kwargs={"temperature": 0.2},
|
||||
created_time=old,
|
||||
)
|
||||
)
|
||||
await repo.create_run(
|
||||
RunCreate(
|
||||
run_id="run-new",
|
||||
thread_id="thread-1",
|
||||
user_id="bob",
|
||||
status="running",
|
||||
model_name="model-b",
|
||||
error="queued",
|
||||
created_time=newer,
|
||||
)
|
||||
)
|
||||
await repo.create_run(RunCreate(run_id="run-other", thread_id="thread-2", status="running"))
|
||||
await repo.update_run_completion(
|
||||
"run-old",
|
||||
status="success",
|
||||
total_input_tokens=7,
|
||||
total_output_tokens=3,
|
||||
total_tokens=10,
|
||||
llm_call_count=1,
|
||||
lead_agent_tokens=8,
|
||||
subagent_tokens=2,
|
||||
first_human_message="hello",
|
||||
last_ai_message="world",
|
||||
)
|
||||
await repo.update_run_completion(
|
||||
"run-new",
|
||||
status="error",
|
||||
total_tokens=5,
|
||||
middleware_tokens=5,
|
||||
error="failed",
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_repository(session)
|
||||
fetched = await repo.get_run("run-old")
|
||||
assert fetched is not None
|
||||
assert fetched.metadata == {"kind": "draft"}
|
||||
assert fetched.kwargs == {"temperature": 0.2}
|
||||
assert fetched.first_human_message == "hello"
|
||||
assert fetched.last_ai_message == "world"
|
||||
|
||||
all_thread_runs = await repo.list_runs_by_thread("thread-1")
|
||||
assert [run.run_id for run in all_thread_runs] == ["run-new", "run-old"]
|
||||
alice_runs = await repo.list_runs_by_thread("thread-1", user_id="alice")
|
||||
assert [run.run_id for run in alice_runs] == ["run-old"]
|
||||
|
||||
pending = await repo.list_pending(before=datetime.now(UTC).isoformat())
|
||||
assert [run.run_id for run in pending] == []
|
||||
|
||||
agg = await repo.aggregate_tokens_by_thread("thread-1")
|
||||
assert agg["total_tokens"] == 15
|
||||
assert agg["total_input_tokens"] == 7
|
||||
assert agg["total_output_tokens"] == 3
|
||||
assert agg["total_runs"] == 2
|
||||
assert agg["by_model"] == {
|
||||
"model-a": {"tokens": 10, "runs": 1},
|
||||
"model-b": {"tokens": 5, "runs": 1},
|
||||
}
|
||||
assert agg["by_caller"] == {"lead_agent": 8, "subagent": 2, "middleware": 5}
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
await repo.create_thread_meta(
|
||||
ThreadMetaCreate(
|
||||
thread_id="thread-1",
|
||||
assistant_id="agent-a",
|
||||
user_id="alice",
|
||||
display_name="Initial",
|
||||
status="idle",
|
||||
metadata={"topic": "finance", "region": "cn"},
|
||||
)
|
||||
)
|
||||
await repo.create_thread_meta(
|
||||
ThreadMetaCreate(
|
||||
thread_id="thread-2",
|
||||
assistant_id="agent-b",
|
||||
user_id="bob",
|
||||
status="running",
|
||||
metadata={"topic": "legal"},
|
||||
)
|
||||
)
|
||||
await repo.update_thread_meta(
|
||||
"thread-1",
|
||||
display_name="Updated",
|
||||
status="running",
|
||||
metadata={"topic": "finance", "region": "us"},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
fetched = await repo.get_thread_meta("thread-1")
|
||||
assert fetched is not None
|
||||
assert fetched.display_name == "Updated"
|
||||
assert fetched.status == "running"
|
||||
assert fetched.metadata == {"topic": "finance", "region": "us"}
|
||||
|
||||
by_metadata = await repo.search_threads(metadata={"topic": "finance"}, user_id="alice")
|
||||
assert [thread.thread_id for thread in by_metadata] == ["thread-1"]
|
||||
by_assistant = await repo.search_threads(assistant_id="agent-b")
|
||||
assert [thread.thread_id for thread in by_assistant] == ["thread-2"]
|
||||
|
||||
await repo.delete_thread("thread-1")
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
assert await repo.get_thread_meta("thread-1") is None
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_metadata_filters_are_type_safe(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-true", metadata={"value": True}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-false", metadata={"value": False}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="int-one", metadata={"value": 1}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="null-value", metadata={"value": None}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="missing-value", metadata={"other": "x"}))
|
||||
await session.commit()
|
||||
|
||||
        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            assert [row.thread_id for row in await repo.search_threads(metadata={"value": True})] == ["bool-true"]
            assert [row.thread_id for row in await repo.search_threads(metadata={"value": False})] == ["bool-false"]
            assert [row.thread_id for row in await repo.search_threads(metadata={"value": 1})] == ["int-one"]
            assert [row.thread_id for row in await repo.search_threads(metadata={"value": None})] == ["null-value"]
    finally:
        await persistence.aclose()


@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filters_paginate_after_sql_match(tmp_path):
    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            for index in range(30):
                metadata = {"target": "yes"} if index % 3 == 0 else {"target": "no"}
                await repo.create_thread_meta(ThreadMetaCreate(thread_id=f"thread-{index:02d}", metadata=metadata))
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            first_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=0)
            second_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=3)
            last_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=9)

            assert len(first_page) == 3
            assert len(second_page) == 3
            assert len(last_page) == 1
            assert {row.thread_id for row in first_page}.isdisjoint({row.thread_id for row in second_page})
    finally:
        await persistence.aclose()


@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filter_rejects_invalid_entries(tmp_path):
    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-1", metadata={"env": "prod"}))
            await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-2", metadata={"env": "staging"}))
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            partial = await repo.search_threads(metadata={"env": "prod", "bad;key": "ignored"})
            assert [row.thread_id for row in partial] == ["thread-1"]

            with pytest.raises(InvalidMetadataFilterError, match="rejected"):
                await repo.search_threads(metadata={"bad;key": "x"})
            with pytest.raises(InvalidMetadataFilterError, match="rejected"):
                await repo.search_threads(metadata={"env": ["prod", "staging"]})
    finally:
        await persistence.aclose()


@pytest.mark.anyio
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            repo = build_feedback_repository(session)
            first = await repo.create_feedback(
                FeedbackCreate(
                    feedback_id="fb-1",
                    run_id="run-1",
                    thread_id="thread-1",
                    rating=1,
                    user_id="alice",
                    message_id="msg-1",
                    comment="good",
                )
            )
            second = await repo.create_feedback(
                FeedbackCreate(
                    feedback_id="fb-2",
                    run_id="run-1",
                    thread_id="thread-1",
                    rating=-1,
                    user_id="bob",
                )
            )
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_feedback_repository(session)
            assert await repo.get_feedback(first.feedback_id) == first
            assert [item.feedback_id for item in await repo.list_feedback_by_run("run-1")] == [
                second.feedback_id,
                first.feedback_id,
            ]
            assert {item.feedback_id for item in await repo.list_feedback_by_thread("thread-1")} == {
                "fb-1",
                "fb-2",
            }
            assert await repo.delete_feedback("fb-1") is True
            assert await repo.delete_feedback("missing") is False
            with pytest.raises(ValueError, match="rating must be"):
                await repo.create_feedback(
                    FeedbackCreate(
                        feedback_id="fb-bad",
                        run_id="run-1",
                        thread_id="thread-1",
                        rating=0,
                    )
                )
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_feedback_repository(session)
            assert await repo.get_feedback("fb-1") is None
    finally:
        await persistence.aclose()


@pytest.mark.anyio
async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_path):
    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            repo = build_run_event_repository(session)
            rows = await repo.append_batch(
                [
                    RunEventCreate(
                        thread_id="thread-1",
                        run_id="run-1",
                        user_id="alice",
                        event_type="message",
                        category="message",
                        content={"role": "user", "content": "hello"},
                        metadata={"source": "input"},
                    ),
                    RunEventCreate(
                        thread_id="thread-1",
                        run_id="run-1",
                        event_type="tool",
                        category="debug",
                        content="tool-call",
                    ),
                    RunEventCreate(
                        thread_id="thread-1",
                        run_id="run-2",
                        event_type="message",
                        category="message",
                        content="second",
                    ),
                    RunEventCreate(
                        thread_id="thread-2",
                        run_id="run-3",
                        event_type="message",
                        category="message",
                        content="other-thread",
                    ),
                ]
            )
            await session.commit()

            assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
            assert [row.seq for row in rows] == sorted(row.seq for row in rows)
            assert rows[1].seq == rows[0].seq + 1
            assert rows[2].seq == rows[1].seq + 1
            assert rows[0].content == {"role": "user", "content": "hello"}
            assert rows[0].metadata == {"source": "input", "content_is_json": True}

        async with persistence.session_factory() as session:
            repo = build_run_event_repository(session)
            messages = await repo.list_messages("thread-1", limit=2)
            assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
            assert await repo.count_messages("thread-1") == 2

            after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
            assert [event.seq for event in after] == [rows[0].seq]
            before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
            assert [event.seq for event in before] == [rows[0].seq]

            events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
            assert [event.content for event in events] == ["tool-call"]

            assert await repo.delete_by_run("thread-1", "run-1") == 2
            assert await repo.delete_by_thread("thread-2") == 1
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_run_event_repository(session)
            remaining = await repo.list_events("thread-1", "run-2")
            assert [event.seq for event in remaining] == [rows[2].seq]
            assert await repo.count_messages("thread-2") == 0

            later = await repo.append_batch(
                [
                    RunEventCreate(
                        thread_id="thread-1",
                        run_id="run-4",
                        event_type="message",
                        category="message",
                        content="after-delete",
                    )
                ]
            )
            assert later[0].seq > rows[2].seq
    finally:
        await persistence.aclose()
@@ -0,0 +1,177 @@
from __future__ import annotations

import asyncio
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from uuid import uuid4

import pytest

os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.repositories import UserCreate, UserNotFoundError, build_user_repository
from store.repositories.models import User as UserModel


@asynccontextmanager
async def _session_factory(tmp_path) -> AsyncGenerator[async_sessionmaker[AsyncSession]]:
    db_path = tmp_path / "storage-users.db"
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
    async with engine.begin() as conn:
        await conn.run_sync(UserModel.metadata.create_all)

    try:
        yield async_sessionmaker(engine, expire_on_commit=False)
    finally:
        await engine.dispose()


async def _create_user(
    session_factory: async_sessionmaker[AsyncSession],
    *,
    email: str = "user@example.com",
    system_role: str = "user",
    oauth_provider: str | None = None,
    oauth_id: str | None = None,
):
    async with session_factory() as session:
        repo = build_user_repository(session)
        user = await repo.create_user(
            UserCreate(
                id=str(uuid4()),
                email=email,
                password_hash="hash",
                system_role=system_role,  # type: ignore[arg-type]
                oauth_provider=oauth_provider,
                oauth_id=oauth_id,
            )
        )
        await session.commit()
        return user


def test_create_and_get_user_by_id_and_email(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            created = await _create_user(session_factory)

            async with session_factory() as session:
                repo = build_user_repository(session)

                by_id = await repo.get_user_by_id(created.id)
                by_email = await repo.get_user_by_email(created.email)

                assert by_id == created
                assert by_email == created
                assert created.system_role == "user"
                assert created.needs_setup is False
                assert created.token_version == 0

    asyncio.run(run())


def test_duplicate_email_raises_value_error(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email="dupe@example.com")

            async with session_factory() as session:
                repo = build_user_repository(session)
                with pytest.raises(ValueError, match="Email already registered"):
                    await repo.create_user(
                        UserCreate(
                            id=str(uuid4()),
                            email="dupe@example.com",
                            password_hash="hash",
                        )
                    )

    asyncio.run(run())


def test_oauth_lookup_and_plain_users_without_oauth(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email="local-1@example.com")
            await _create_user(session_factory, email="local-2@example.com")
            oauth_user = await _create_user(
                session_factory,
                email="oauth@example.com",
                oauth_provider="github",
                oauth_id="gh-123",
            )

            async with session_factory() as session:
                repo = build_user_repository(session)

                assert await repo.count_users() == 3
                assert await repo.get_user_by_oauth("github", "gh-123") == oauth_user
                assert await repo.get_user_by_oauth("github", "missing") is None

    asyncio.run(run())


def test_count_admins_and_get_first_admin(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email="user@example.com")
            admin = await _create_user(
                session_factory,
                email="admin@example.com",
                system_role="admin",
            )

            async with session_factory() as session:
                repo = build_user_repository(session)

                assert await repo.count_users() == 2
                assert await repo.count_admin_users() == 1
                assert await repo.get_first_admin() == admin

    asyncio.run(run())


def test_update_user_round_trips_token_version_and_setup_state(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            created = await _create_user(session_factory)
            updated = created.model_copy(
                update={
                    "email": "renamed@example.com",
                    "token_version": 4,
                    "needs_setup": True,
                }
            )

            async with session_factory() as session:
                repo = build_user_repository(session)
                saved = await repo.update_user(updated)
                await session.commit()

            async with session_factory() as session:
                repo = build_user_repository(session)
                fetched = await repo.get_user_by_id(created.id)

                assert saved.email == "renamed@example.com"
                assert fetched == updated

    asyncio.run(run())


def test_update_missing_user_raises(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            missing = UserCreate(id=str(uuid4()), email="missing@example.com")

            async with session_factory() as session:
                repo = build_user_repository(session)
                created_shape = await repo.create_user(missing)
                await session.rollback()

                with pytest.raises(UserNotFoundError):
                    await repo.update_user(created_shape)

    asyncio.run(run())
@@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
     )


-def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
+def _runtime(
+    thread_id: str | None = "thread-1",
+    agent_name: str | None = None,
+    user_id: str | None = None,
+) -> SimpleNamespace:
     context = {}
     if thread_id is not None:
         context["thread_id"] = thread_id
     if agent_name is not None:
         context["agent_name"] = agent_name
+    if user_id is not None:
+        context["user_id"] = user_id
     return SimpleNamespace(context=context)


@@ -634,3 +640,22 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon

     queue.add_nowait.assert_called_once()
     assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
+
+
+def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
+    queue = MagicMock()
+    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
+    monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
+
+    memory_flush_hook(
+        SummarizationEvent(
+            messages_to_summarize=tuple(_messages()[:2]),
+            preserved_messages=(),
+            thread_id="main",
+            agent_name="researcher",
+            runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"),
+        )
+    )
+
+    queue.add_nowait.assert_called_once()
+    assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
@@ -1214,11 +1214,12 @@ def test_terminal_event_usage_none_when_no_records(monkeypatch):
     assert completed[0]["usage"] is None


-def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
+@pytest.mark.parametrize("error", [FileNotFoundError("missing config"), ValueError("invalid config")])
+def test_subagent_usage_cache_is_skipped_when_default_config_cannot_load(monkeypatch, error):
     monkeypatch.setattr(
         task_tool_module,
         "get_app_config",
-        MagicMock(side_effect=FileNotFoundError("missing config")),
+        MagicMock(side_effect=error),
     )

     assert task_tool_module._token_usage_cache_enabled(None) is False
Generated file: +93 -1
@@ -14,6 +14,7 @@ resolution-markers = [
 members = [
     "deer-flow",
     "deerflow-harness",
+    "deerflow-storage",
 ]

 [[package]]
@@ -136,6 +137,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441, upload-time = "2026-03-31T22:00:12.791Z" },
 ]

+[[package]]
+name = "aiomysql"
+version = "0.3.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pymysql" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/29/e0/302aeffe8d90853556f47f3106b89c16cc2ec2a4d269bdfd82e3f4ae12cc/aiomysql-0.3.2.tar.gz", hash = "sha256:72d15ef5cfc34c03468eb41e1b90adb9fd9347b0b589114bd23ead569a02ac1a", size = 108311, upload-time = "2025-10-22T00:15:21.278Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/4c/af/aae0153c3e28712adaf462328f6c7a3c196a1c1c27b491de4377dd3e6b52/aiomysql-0.3.2-py3-none-any.whl", hash = "sha256:c82c5ba04137d7afd5c693a258bea8ead2aad77101668044143a991e04632eb2", size = 71834, upload-time = "2025-10-22T00:15:15.905Z" },
+]
+
 [[package]]
 name = "aiosignal"
 version = "1.4.0"
@@ -746,6 +759,7 @@ source = { virtual = "." }
 dependencies = [
     { name = "bcrypt" },
     { name = "deerflow-harness" },
+    { name = "deerflow-storage" },
     { name = "dingtalk-stream" },
     { name = "email-validator" },
     { name = "fastapi" },
@@ -763,8 +777,12 @@ dependencies = [
 ]

 [package.optional-dependencies]
+mysql = [
+    { name = "deerflow-storage", extra = ["mysql"] },
+]
 postgres = [
     { name = "deerflow-harness", extra = ["postgres"] },
+    { name = "deerflow-storage", extra = ["postgres"] },
 ]

 [package.dev-dependencies]
@@ -780,6 +798,9 @@ requires-dist = [
     { name = "bcrypt", specifier = ">=4.0.0" },
     { name = "deerflow-harness", editable = "packages/harness" },
     { name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
+    { name = "deerflow-storage", editable = "packages/storage" },
+    { name = "deerflow-storage", extras = ["mysql"], marker = "extra == 'mysql'", editable = "packages/storage" },
+    { name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
     { name = "dingtalk-stream", specifier = ">=0.24.3" },
     { name = "email-validator", specifier = ">=2.0.0" },
     { name = "fastapi", specifier = ">=0.115.0" },
@@ -795,7 +816,7 @@ requires-dist = [
     { name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
     { name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
 ]
-provides-extras = ["postgres"]
+provides-extras = ["postgres", "mysql"]

 [package.metadata.requires-dev]
 dev = [
@@ -901,6 +922,54 @@ requires-dist = [
 ]
 provides-extras = ["ollama", "postgres", "pymupdf"]

+[[package]]
+name = "deerflow-storage"
+version = "0.1.0"
+source = { editable = "packages/storage" }
+dependencies = [
+    { name = "alembic" },
+    { name = "dotenv" },
+    { name = "langgraph" },
+    { name = "pydantic" },
+    { name = "pyyaml" },
+    { name = "sqlalchemy", extra = ["asyncio"] },
+]
+
+[package.optional-dependencies]
+mysql = [
+    { name = "aiomysql" },
+    { name = "langgraph-checkpoint-mysql" },
+]
+postgres = [
+    { name = "asyncpg" },
+    { name = "langgraph-checkpoint-postgres" },
+    { name = "psycopg", extra = ["binary"] },
+    { name = "psycopg-pool" },
+]
+sqlite = [
+    { name = "aiosqlite" },
+    { name = "langgraph-checkpoint-sqlite" },
+]
+
+[package.metadata]
+requires-dist = [
+    { name = "aiomysql", marker = "extra == 'mysql'", specifier = ">=0.2" },
+    { name = "aiosqlite", marker = "extra == 'sqlite'", specifier = ">=0.22.1" },
+    { name = "alembic", specifier = ">=1.13" },
+    { name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
+    { name = "dotenv", specifier = ">=0.9.9" },
+    { name = "langgraph", specifier = ">=1.1.9" },
+    { name = "langgraph-checkpoint-mysql", marker = "extra == 'mysql'", specifier = ">=3.0.0" },
+    { name = "langgraph-checkpoint-postgres", marker = "extra == 'postgres'", specifier = ">=3.0.5" },
+    { name = "langgraph-checkpoint-sqlite", marker = "extra == 'sqlite'", specifier = ">=3.0.3" },
+    { name = "psycopg", extras = ["binary"], marker = "extra == 'postgres'", specifier = ">=3.3.3" },
+    { name = "psycopg-pool", marker = "extra == 'postgres'", specifier = ">=3.3.0" },
+    { name = "pydantic", specifier = ">=2.12.5" },
+    { name = "pyyaml", specifier = ">=6.0.3" },
+    { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0,<3.0" },
+]
+provides-extras = ["postgres", "mysql", "sqlite"]
+
 [[package]]
 name = "defusedxml"
 version = "0.7.1"
@@ -1914,6 +1983,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b9/5a/6dba29dd89b0a46ae21c707da0f9d17e94f27d3e481ed15bc99d6bd20aa6/langgraph_checkpoint-4.0.2-py3-none-any.whl", hash = "sha256:59b0f29216128a629c58dd07c98aa004f82f51805d5573126ffb419b753ff253", size = 51000, upload-time = "2026-04-15T21:02:59.096Z" },
 ]

+[[package]]
+name = "langgraph-checkpoint-mysql"
+version = "3.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "langgraph-checkpoint" },
+    { name = "orjson" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/e0/4e/0a6c78e5d3f2ca1525903c2363e721873594b6b77dd83537a6369193c474/langgraph_checkpoint_mysql-3.0.0.tar.gz", hash = "sha256:006aaa089f4c2fbd7b2c113b800ccd3dbb95f92203e656451677256b4b4f880f", size = 213142, upload-time = "2026-01-23T11:11:15.74Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/08/68/343103a7fae05523f9cecabbec2babdb737e66b4bf6ea48ae00c685ed11c/langgraph_checkpoint_mysql-3.0.0-py3-none-any.whl", hash = "sha256:7560ccd16e7596a047e15a307cec12dbd88fdcaab45a75759e5c6adef22a27d1", size = 38009, upload-time = "2026-01-23T11:11:14.697Z" },
+]
+
 [[package]]
 name = "langgraph-checkpoint-postgres"
 version = "3.0.5"
@@ -3442,6 +3525,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e6/38/84bf29f4dd72e6c450546df6ca8f53021f764fd945ba67dcc235d39bc20e/pymupdf4llm-1.27.2.3-py3-none-any.whl", hash = "sha256:bd724b79fa3f06a5b28d7a65f7acfa8de56e04bdb603ac2d6dff315e0d151aaa", size = 77348, upload-time = "2026-04-24T14:11:04.305Z" },
 ]

+[[package]]
+name = "pymysql"
+version = "1.1.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/7f/ec/8d45c920e90445f0b75c590b32851853ed319763b0d8dff8d283052da8cf/pymysql-1.1.3.tar.gz", hash = "sha256:e70ebf2047a4edf6138cf79c68ad418ef620af65900aa585c5e8bfc95044d43a", size = 48207, upload-time = "2026-05-01T09:09:54.532Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/8e/dc/9085f3d6f497e9b25fb40d6e8ecef3ddbb5cf977a949b933624a299f5c16/pymysql-1.1.3-py3-none-any.whl", hash = "sha256:8164ba62c552f6105f3b11753352d0f16b90d1703ba67d81923d5a8a5d1c5289", size = 45356, upload-time = "2026-05-01T09:09:53.316Z" },
+]
+
 [[package]]
 name = "pypdfium2"
 version = "5.7.1"