Compare commits

10 Commits

| Author | SHA1 | Date |
|---|---|---|
| | d7a2fff7e0 | |
| | eabd78ce4e | |
| | 533d3fbfee | |
| | d6b3a277a5 | |
| | def2a3ad79 | |
| | 3c0b42d836 | |
| | 34ec205e1d | |
| | 11a9041b65 | |
| | d3066a1746 | |
| | 485f8a2bf2 | |
@@ -0,0 +1,401 @@
# Storage Package Design

## Background

DeerFlow currently has several persistence responsibilities spread across app, gateway, runtime, and legacy persistence modules. This makes the persistence boundary difficult to reason about and creates several migration risks:

- Routers and runtime services can accidentally depend on concrete persistence implementations instead of stable contracts.
- User/auth, run metadata, thread metadata, feedback, run events, and checkpointer setup are initialized through different paths.
- Some persistence behavior is duplicated between memory, SQLite, and PostgreSQL-oriented code paths.
- Incremental migration is hard because app-level code and storage-level code are coupled.
- Adding or validating another SQL backend requires touching app/runtime code instead of a storage-owned package.

The storage package is introduced to make application data persistence a package-level capability with explicit contracts, a clear boundary, and SQL backend compatibility.

## Goals

- Provide a standalone `packages/storage` package for durable application data.
- Support SQLite, PostgreSQL, and MySQL through a shared persistence construction flow.
- Keep LangGraph checkpointer initialization compatible with the same database backend.
- Expose repository contracts as the only package-level data access boundary.
- Let the app layer depend on app-owned adapters under `app.infra.storage`, not on storage DB implementation classes.
- Allow the app/gateway migration to happen in small steps without forcing a large rewrite.

## Non-Goals

- This design does not remove legacy persistence in the first PR.
- This design does not move routers directly onto storage package models.
- This design does not make app routers own SQLAlchemy sessions.
- Cron persistence is intentionally out of scope for the storage package foundation.
- The memory backend is not part of the durable storage package. Memory compatibility, if still needed by the app runtime, belongs outside `packages/storage`.

## Storage Design Principles

### Package-Owned Durable Storage

`packages/storage` owns durable application data persistence. It defines:

- configuration shape for storage-backed persistence
- SQLAlchemy models
- repository contracts and DTOs
- SQL repository implementations
- persistence factory functions
- compatibility helpers for config-driven initialization

The package should be usable without importing `app.gateway`, routers, auth providers, or runtime-specific gateway objects.

### SQL Backend Compatibility

The package supports three SQL backends:

- SQLite for local/single-node deployments
- PostgreSQL for production multi-node deployments
- MySQL for deployments that standardize on MySQL

Backend-specific differences are handled inside the storage package:

- SQLAlchemy async engine URL construction
- LangGraph checkpointer connection-string compatibility
- JSON metadata filtering across SQLite/PostgreSQL/MySQL
- SQL dialect behavior around locking, aggregation, and JSON type semantics

### Unified Persistence Bundle

Storage initialization returns an `AppPersistence` bundle:

```python
@dataclass(slots=True)
class AppPersistence:
    checkpointer: Checkpointer
    engine: AsyncEngine
    session_factory: async_sessionmaker[AsyncSession]
    setup: Callable[[], Awaitable[None]]
    aclose: Callable[[], Awaitable[None]]
```

The app runtime can initialize persistence once, call `setup()`, and then inject:

- `checkpointer`
- `session_factory`
- repository adapters

This keeps checkpointer and application data aligned to the same backend without requiring routers to understand database configuration.

## Package Layout

```text
backend/packages/storage/
  store/
    config/
      storage_config.py
      app_config.py
    persistence/
      factory.py
      types.py
      base_model.py
      json_compat.py
      drivers/
        sqlite.py
        postgres.py
        mysql.py
    repositories/
      contracts/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      models/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      db/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      factory.py
```

## Persistence Construction

The primary storage entrypoint is:

```python
from store.persistence import create_persistence_from_storage_config

persistence = await create_persistence_from_storage_config(storage_config)
await persistence.setup()
```

For app-level compatibility with the existing database config shape:

```python
from store.persistence import create_persistence_from_database_config

persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
```

Expected app startup flow:

```python
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()

app.state.persistence = persistence
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```

Expected app shutdown flow:

```python
await app.state.persistence.aclose()
```
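
The startup and shutdown flows above pair naturally in a single lifespan-style async context manager. The following is a minimal sketch under stated assumptions, not the actual gateway wiring: `FakePersistence` and `create_persistence` are hypothetical stand-ins for `AppPersistence` and `create_persistence_from_database_config`, and `app` is any object with a `state` attribute.

```python
import asyncio
from contextlib import asynccontextmanager
from types import SimpleNamespace


class FakePersistence:
    """Hypothetical stand-in for AppPersistence; records lifecycle calls."""

    def __init__(self) -> None:
        self.checkpointer = object()
        self.session_factory = object()
        self.calls: list[str] = []

    async def setup(self) -> None:
        self.calls.append("setup")

    async def aclose(self) -> None:
        self.calls.append("aclose")


async def create_persistence(_database_config) -> FakePersistence:
    # Stand-in for create_persistence_from_database_config.
    return FakePersistence()


@asynccontextmanager
async def lifespan(app):
    persistence = await create_persistence(None)
    await persistence.setup()
    app.state.persistence = persistence
    app.state.checkpointer = persistence.checkpointer
    app.state.session_factory = persistence.session_factory
    try:
        yield
    finally:
        # Runs on shutdown, even if the body of the context raised.
        await persistence.aclose()


async def demo() -> list[str]:
    app = SimpleNamespace(state=SimpleNamespace())
    async with lifespan(app):
        assert app.state.persistence.calls == ["setup"]
    return app.state.persistence.calls
```

Putting `aclose()` in a `finally` block keeps engine and checkpointer cleanup tied to the same scope that created them.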

## Repository Contract Design

Repository contracts are the storage package's public data access boundary. They live under `store.repositories.contracts` and are re-exported from `store.repositories`.

The key contract groups are:

- `UserRepositoryProtocol`
- `RunRepositoryProtocol`
- `ThreadMetaRepositoryProtocol`
- `FeedbackRepositoryProtocol`
- `RunEventRepositoryProtocol`

Each contract owns:

- input DTOs, such as `UserCreate`, `RunCreate`, `ThreadMetaCreate`
- output DTOs, such as `User`, `Run`, `ThreadMeta`
- repository protocol methods
- domain-specific exceptions when needed, such as `InvalidMetadataFilterError`

Repository construction is session-based:

```python
from store.repositories import build_run_repository

async with persistence.session_factory() as session:
    repo = build_run_repository(session)
    run = await repo.get_run(run_id)
```

This keeps transaction ownership explicit. The storage package does not hide commits or session lifecycle inside global singletons.

## App/Infra Calling Contract

The app layer should not call `store.repositories.db.*` directly. The intended app boundary is `app.infra.storage`.

`app.infra.storage` is responsible for:

- receiving `session_factory` from FastAPI runtime initialization
- owning session lifecycle for app-facing repository methods
- translating storage DTOs to app/gateway DTOs only when needed
- preserving the existing app-facing names during migration
- depending on storage repository protocols, not concrete DB classes

Expected adapter pattern:

```python
class StorageRunRepository(RunRepositoryProtocol):
    def __init__(self, session_factory):
        self._session_factory = session_factory

    async def get_run(self, run_id: str):
        async with self._session_factory() as session:
            repo = build_run_repository(session)
            return await repo.get_run(run_id)
```

For gateway compatibility, app state can keep existing names while the implementation changes:

```python
app.state.run_store = StorageRunStore(run_repository)
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
app.state.run_event_store = StorageRunEventStore(run_event_repository)
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```

The app-facing objects may expose legacy method names during migration, but their internal data access should go through storage contracts.

## Boundary Rules

### Allowed Calls

Storage package callers may use:

```python
from store.persistence import create_persistence_from_database_config
from store.persistence import create_persistence_from_storage_config
from store.repositories import build_run_repository
from store.repositories import build_user_repository
from store.repositories import build_thread_meta_repository
from store.repositories import build_feedback_repository
from store.repositories import build_run_event_repository
from store.repositories import RunRepositoryProtocol
from store.repositories import UserRepositoryProtocol
```

App layer callers should use:

```python
from app.infra.storage import StorageRunRepository
from app.infra.storage import StorageUserDataRepository
from app.infra.storage import StorageThreadMetaRepository
from app.infra.storage import StorageFeedbackRepository
from app.infra.storage import StorageRunEventRepository
```

### Prohibited Calls

App/gateway/router/auth code must not import:

```python
from store.repositories.db import DbRunRepository
from store.repositories.models import Run
from store.persistence.base_model import MappedBase
```

Routers must not:

- create SQLAlchemy engines
- create SQLAlchemy sessions directly
- call storage DB repository classes directly
- commit/rollback storage transactions directly unless explicitly scoped by an infra adapter
- depend on storage SQLAlchemy model classes

Storage package code must not import:

```python
import app.gateway
import app.infra
import deerflow.runtime
```

The dependency direction is:

```text
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
```

The reverse direction is forbidden.
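
A boundary like this is easier to keep when a test enforces it. The sketch below is one possible check, not an existing tool in the repository: `forbidden_imports` is a hypothetical helper that scans Python source with the standard `ast` module and reports absolute imports that would point the storage package back at the app or runtime.

```python
import ast

# Module prefixes the storage package must not import, per the rules above.
FORBIDDEN_PREFIXES = ("app.gateway", "app.infra", "deerflow.runtime")


def forbidden_imports(source: str) -> list[str]:
    """Return imported names in `source` that violate the storage boundary."""
    found = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, ast.Import):
            names = [alias.name for alias in node.names]
        elif isinstance(node, ast.ImportFrom) and node.level == 0 and node.module:
            # Join module and imported name so `from app import gateway` is caught.
            names = [f"{node.module}.{alias.name}" for alias in node.names]
        else:
            # Relative imports stay inside the package and are ignored.
            continue
        for name in names:
            if any(name == p or name.startswith(p + ".") for p in FORBIDDEN_PREFIXES):
                found.append(name)
    return found
```

A test could run this over every file under `backend/packages/storage` and fail if the result is non-empty. Dynamic imports such as `importlib.import_module(...)` are not detected by this approach.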

## Checkpointer Compatibility

The storage persistence bundle initializes the LangGraph checkpointer alongside application data persistence.

Backend-specific notes:

- SQLite uses `langgraph-checkpoint-sqlite`.
- PostgreSQL uses `langgraph-checkpoint-postgres` and requires a string `postgresql://...` connection URL.
- MySQL uses `langgraph-checkpoint-mysql` and requires a string MySQL connection URL.

SQLAlchemy may use async driver URLs such as `postgresql+asyncpg://...` or `mysql+aiomysql://...`, but LangGraph checkpointer constructors expect plain string connection URLs. This conversion belongs inside the storage driver implementation.
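
The conversion can be as small as dropping the `+driver` suffix from the URL scheme. The helper below is a hypothetical sketch (`to_checkpointer_url` is not a name from the storage package), and it assumes the rest of the URL is already acceptable to the checkpointer; driver-specific query parameters and SQLite file paths may need extra handling.

```python
def to_checkpointer_url(sqlalchemy_url: str) -> str:
    """Drop the '+driver' part of a SQLAlchemy-style URL scheme."""
    scheme, sep, rest = sqlalchemy_url.partition("://")
    if not sep:
        # The URL is deliberately left out of the message; it may hold a password.
        raise ValueError("expected a URL of the form dialect[+driver]://...")
    dialect = scheme.split("+", 1)[0]
    return f"{dialect}://{rest}"
```

Keeping this in the driver module means callers pass a single database URL and never see the two URL flavors.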

## JSON Metadata Filtering

Thread metadata search supports dialect-aware JSON filtering through `store.persistence.json_compat`.

The matcher supports:

- `None`
- `bool`
- `int`
- `float`
- `str`

It rejects:

- unsafe keys
- nested JSON path expressions
- dict/list values
- integers outside signed 64-bit range

This prevents SQL/JSON path injection, avoids compiled-cache type drift, and preserves type semantics such as `True != 1` and explicit JSON `null` not matching a missing key.
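
The accept/reject rules above can be illustrated with a small validator. This is a sketch under assumptions, not the contents of `json_compat`: the key pattern, the type tags, and the `validate_filter` name are hypothetical, and the exception is redefined locally so the example is self-contained.

```python
import re

_SAFE_KEY = re.compile(r"[A-Za-z0-9_-]+")  # assumed policy for safe keys
_INT64_MIN, _INT64_MAX = -(2**63), 2**63 - 1


class InvalidMetadataFilterError(ValueError):
    """Raised for filter keys or values that cannot be matched safely."""


def validate_filter(key: str, value: object) -> tuple[str, object]:
    """Return (type_tag, value) for a supported scalar, otherwise raise."""
    # fullmatch rejects dots, brackets, and trailing newlines in keys.
    if not isinstance(key, str) or not _SAFE_KEY.fullmatch(key):
        raise InvalidMetadataFilterError(f"unsafe key: {key!r}")
    if value is None:
        return ("null", None)
    # bool is checked before int because bool is a subclass of int.
    if isinstance(value, bool):
        return ("bool", value)
    if isinstance(value, int):
        if not _INT64_MIN <= value <= _INT64_MAX:
            raise InvalidMetadataFilterError("integer outside signed 64-bit range")
        return ("int", value)
    if isinstance(value, float):
        return ("float", value)
    if isinstance(value, str):
        return ("str", value)
    raise InvalidMetadataFilterError(f"unsupported value type: {type(value).__name__}")
```

Carrying an explicit type tag alongside the value is what lets a dialect-specific matcher treat `True` and `1` as different filters even though they compare equal in Python.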

## Step-by-Step Implementation Plan

### Step 1: Introduce Storage Package Foundation

- Add `backend/packages/storage`.
- Add storage config models.
- Add `AppPersistence`.
- Add SQLite/PostgreSQL/MySQL persistence drivers.
- Add repository contracts, models, DB implementations, and factory helpers.
- Add package dependency wiring.
- Exclude cron persistence.

### Step 2: Harden Storage Backend Compatibility

- Validate SQLite setup and repository behavior.
- Validate PostgreSQL and MySQL with local E2E tests.
- Fix checkpointer connection-string compatibility.
- Fix PostgreSQL locking and aggregation differences.
- Add dialect-aware JSON metadata filtering.

### Step 3: Add App Infra Adapters

- Add `backend/app/infra/storage`.
- Implement app-facing repositories that own session lifecycle.
- Keep storage contracts as the only data access boundary.
- Add legacy compatibility adapters for existing app/gateway method shapes.
- Keep app/gateway imports out of `packages/storage`.

### Step 4: Switch FastAPI Runtime Injection

- Initialize storage persistence in FastAPI startup/lifespan.
- Attach `persistence`, `checkpointer`, and `session_factory` to `app.state`.
- Preserve existing external state names:
  - `run_store`
  - `feedback_repo`
  - `thread_store`
  - `run_event_store`
  - `checkpointer`
  - `session_factory`
- Start with user/auth provider construction, then migrate run/thread/feedback/run_event.

### Step 5: Router and Auth Compatibility

- Ensure routers consume app-facing adapters, not storage DB classes.
- Ensure auth providers depend on user repository contracts.
- Keep router response shapes unchanged.
- Add focused auth/admin/router regression tests.

### Step 6: Cleanup Legacy Persistence

- Compare old persistence usage after app/gateway migration.
- Remove unused old repository implementations only after all call sites move.
- Keep compatibility shims only where needed for a transition window.
- Delete memory backend paths from storage-owned durable persistence.

## Testing Strategy

Unit tests should cover:

- config parsing
- persistence setup
- table creation
- repository CRUD/query behavior
- typed JSON metadata filtering
- dialect SQL compilation
- cron exclusion

E2E tests should cover:

- SQLite persistence setup
- PostgreSQL temporary database setup
- MySQL temporary database setup
- repository contract behavior across all supported SQL backends
- JSON/Unicode round trip
- rollback behavior
- persistence close/cleanup

E2E tests may remain local-only if CI does not provide PostgreSQL/MySQL services.
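
One common way to keep such tests local-only is to skip them unless a connection URL is supplied through the environment. The sketch below uses the standard `unittest` module; the variable name `STORAGE_E2E_POSTGRES_URL` and the test class are assumptions for illustration, and the test body is a placeholder rather than a real E2E check.

```python
import os
import unittest

# Hypothetical variable name; left unset where no PostgreSQL service exists.
POSTGRES_URL = os.environ.get("STORAGE_E2E_POSTGRES_URL")


@unittest.skipUnless(POSTGRES_URL, "set STORAGE_E2E_POSTGRES_URL to run PostgreSQL E2E tests")
class PostgresPersistenceE2E(unittest.TestCase):
    def test_url_targets_postgres(self) -> None:
        # Placeholder assertion. A real test would create a temporary database,
        # call setup(), exercise the repository contracts, then call aclose().
        self.assertTrue(POSTGRES_URL.startswith("postgresql"))
```

With this shape the suite reports the E2E cases as skipped in CI instead of failing, which keeps the missing coverage visible.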
@@ -0,0 +1,401 @@
# Storage Package Design Document

## Background

DeerFlow currently has several kinds of persistence responsibilities scattered across the app, gateway, runtime, and legacy persistence modules. This causes several problems:

- Routers and runtime services tend to depend on concrete persistence implementations rather than stable contracts.
- The initialization paths for user/auth, run metadata, thread metadata, feedback, run events, and checkpointer setup are not unified.
- Some logic is duplicated across the memory, SQLite, and PostgreSQL code paths.
- App-layer code and storage-layer code are coupled, which makes incremental migration difficult.
- Adding or validating a new SQL backend requires changes to app/runtime code rather than to the storage package alone.

The storage package is introduced to turn application data persistence into a package-level capability with explicit contracts, clear boundaries, and SQL backend compatibility.

## Goals

- Add a standalone `packages/storage` package responsible for durable application data.
- Support SQLite, PostgreSQL, and MySQL through a unified persistence construction flow.
- Keep the LangGraph checkpointer compatible with the same database backend.
- Make repository contracts the package's only external data access boundary.
- Have the app layer adapt to storage through `app.infra.storage` instead of depending directly on storage DB implementation classes.
- Allow the app/gateway to migrate in small follow-up steps, avoiding a one-shot large refactor.

## Non-Goals

- The first phase does not remove legacy persistence.
- Routers are not made to depend directly on storage package models.
- App routers are not made to manage SQLAlchemy sessions.
- Cron persistence is outside the scope of the storage package foundation migration.
- The memory backend is not part of the durable storage package. If the app runtime still needs memory compatibility, it should live outside `packages/storage`.

## Storage Design Principles

### The Package Owns Durable Storage

`packages/storage` is responsible for durable persistence of application data, including:

- storage persistence configuration
- SQLAlchemy models
- repository contracts and DTOs
- SQL repository implementations
- persistence factory functions
- a compatibility initialization entrypoint for the existing config

The package should not import `app.gateway`, routers, auth providers, or gateway objects from the runtime.

### SQL Backend Compatibility

The package supports three SQL backends:

- SQLite: local or single-node deployments
- PostgreSQL: production multi-node deployments
- MySQL: deployments that use MySQL as their standard database

Backend differences are handled inside the storage package:

- SQLAlchemy async engine URL construction
- LangGraph checkpointer connection-string compatibility
- JSON metadata filters for SQLite/PostgreSQL/MySQL
- differences between SQL dialects in locking, aggregation, and JSON type semantics

### Unified Persistence Bundle

Storage initialization returns an `AppPersistence` bundle:

```python
@dataclass(slots=True)
class AppPersistence:
    checkpointer: Checkpointer
    engine: AsyncEngine
    session_factory: async_sessionmaker[AsyncSession]
    setup: Callable[[], Awaitable[None]]
    aclose: Callable[[], Awaitable[None]]
```

The app runtime only needs to initialize persistence once, call `setup()`, and then inject:

- `checkpointer`
- `session_factory`
- repository adapters

This way the checkpointer and application data stay aligned on the same backend, and routers do not need to understand database configuration.

## Package Layout

```text
backend/packages/storage/
  store/
    config/
      storage_config.py
      app_config.py
    persistence/
      factory.py
      types.py
      base_model.py
      json_compat.py
      drivers/
        sqlite.py
        postgres.py
        mysql.py
    repositories/
      contracts/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      models/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      db/
        user.py
        run.py
        thread_meta.py
        feedback.py
        run_event.py
      factory.py
```

## Persistence Construction

The main storage entrypoint:

```python
from store.persistence import create_persistence_from_storage_config

persistence = await create_persistence_from_storage_config(storage_config)
await persistence.setup()
```

For compatibility with the existing app database config, the package also provides:

```python
from store.persistence import create_persistence_from_database_config

persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
```

Expected app startup flow:

```python
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()

app.state.persistence = persistence
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```

Expected app shutdown flow:

```python
await app.state.persistence.aclose()
```

## Repository Contract Design

Repository contracts are the data access boundary that the storage package exposes publicly. They live in `store.repositories.contracts` and are re-exported through `store.repositories`.

The main contracts are:

- `UserRepositoryProtocol`
- `RunRepositoryProtocol`
- `ThreadMetaRepositoryProtocol`
- `FeedbackRepositoryProtocol`
- `RunEventRepositoryProtocol`

Each contract group includes:

- input DTOs, for example `UserCreate`, `RunCreate`, `ThreadMetaCreate`
- output DTOs, for example `User`, `Run`, `ThreadMeta`
- repository protocol methods
- domain exceptions where necessary, for example `InvalidMetadataFilterError`

Repositories are constructed from a session:

```python
from store.repositories import build_run_repository

async with persistence.session_factory() as session:
    repo = build_run_repository(session)
    run = await repo.get_run(run_id)
```

This keeps transaction ownership explicit. The storage package does not implicitly hide commits or the session lifecycle behind global singletons.

## App/Infra Calling Contract

The app layer should not call `store.repositories.db.*` directly. The intended app boundary is `app.infra.storage`.

`app.infra.storage` is responsible for:

- receiving `session_factory` from FastAPI runtime initialization
- managing the session lifecycle for app-facing repository methods
- converting storage DTOs to app/gateway DTOs when necessary
- keeping the existing app-facing names during migration
- depending on storage repository protocols rather than concrete DB classes

Expected adapter pattern:

```python
class StorageRunRepository(RunRepositoryProtocol):
    def __init__(self, session_factory):
        self._session_factory = session_factory

    async def get_run(self, run_id: str):
        async with self._session_factory() as session:
            repo = build_run_repository(session)
            return await repo.get_run(run_id)
```

For gateway compatibility, app state can temporarily keep its existing names and replace only the internal implementation:

```python
app.state.run_store = StorageRunStore(run_repository)
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
app.state.run_event_store = StorageRunEventStore(run_event_repository)
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```

App-facing objects may keep legacy method names during migration, but internal data access must go through storage contracts.

## Boundary Rules

### Allowed Calls

Callers of the storage package may use:

```python
from store.persistence import create_persistence_from_database_config
from store.persistence import create_persistence_from_storage_config
from store.repositories import build_run_repository
from store.repositories import build_user_repository
from store.repositories import build_thread_meta_repository
from store.repositories import build_feedback_repository
from store.repositories import build_run_event_repository
from store.repositories import RunRepositoryProtocol
from store.repositories import UserRepositoryProtocol
```

The app layer should use:

```python
from app.infra.storage import StorageRunRepository
from app.infra.storage import StorageUserDataRepository
from app.infra.storage import StorageThreadMetaRepository
from app.infra.storage import StorageFeedbackRepository
from app.infra.storage import StorageRunEventRepository
```

### Prohibited Calls

App/gateway/router/auth code should not import:

```python
from store.repositories.db import DbRunRepository
from store.repositories.models import Run
from store.persistence.base_model import MappedBase
```

Routers must not:

- create SQLAlchemy engines
- create SQLAlchemy sessions directly
- call storage DB repository classes directly
- commit or roll back storage transactions directly, unless this falls within a scope explicitly managed by an infra adapter
- depend on storage SQLAlchemy model classes

The storage package must not import:

```python
import app.gateway
import app.infra
import deerflow.runtime
```

The dependency direction must be:

```text
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
```

Reverse dependencies are forbidden.

## Checkpointer Compatibility

The storage persistence bundle initializes both the LangGraph checkpointer and application data persistence.

Backend notes:

- SQLite uses `langgraph-checkpoint-sqlite`.
- PostgreSQL uses `langgraph-checkpoint-postgres` and needs a `postgresql://...` connection string in plain string form.
- MySQL uses `langgraph-checkpoint-mysql` and needs a MySQL connection string in plain string form.

SQLAlchemy can use async driver URLs such as `postgresql+asyncpg://...` or `mysql+aiomysql://...`, but LangGraph checkpointer constructors need plain string connection strings. This conversion should be encapsulated inside the storage driver implementation.

## JSON Metadata Filtering

Thread metadata search supports cross-dialect JSON filtering through `store.persistence.json_compat`.

Supported filter value types:

- `None`
- `bool`
- `int`
- `float`
- `str`

Rejected:

- unsafe keys
- nested JSON path expressions
- dict/list values
- integers outside the signed 64-bit range

This prevents SQL/JSON path injection, avoids compiled-cache type drift, and preserves type semantics, for example `True != 1` and an explicit JSON `null` not being equal to a missing key.

## Step-by-Step Implementation Plan

### Step 1: Add the Storage Package Foundation

- Add `backend/packages/storage`.
- Add storage config models.
- Add `AppPersistence`.
- Add SQLite/PostgreSQL/MySQL persistence drivers.
- Add repository contracts, models, DB implementations, and factory helpers.
- Wire in the package dependency.
- Exclude cron persistence.

### Step 2: Complete Storage Backend Compatibility

- Validate SQLite setup and repository behavior.
- Validate PostgreSQL and MySQL with local E2E tests.
- Fix checkpointer connection-string compatibility.
- Fix PostgreSQL locking and aggregation differences.
- Add cross-dialect JSON metadata filtering.

### Step 3: Add App Infra Adapters

- Add `backend/app/infra/storage`.
- Implement app-facing repositories that manage the session lifecycle.
- Keep storage contracts as the only data access boundary.
- Add compatibility adapters for existing app/gateway method shapes.
- Prevent `packages/storage` from importing app/gateway.

### Step 4: Switch FastAPI Runtime Injection

- Initialize storage persistence in FastAPI startup/lifespan.
- Inject `persistence`, `checkpointer`, and `session_factory` into `app.state`.
- Temporarily keep the existing external state names:
  - `run_store`
  - `feedback_repo`
  - `thread_store`
  - `run_event_store`
  - `checkpointer`
  - `session_factory`
- Switch user/auth provider construction first, then gradually migrate run/thread/feedback/run_event.

### Step 5: Router and Auth Compatibility

- Ensure routers consume app-facing adapters rather than storage DB classes.
- Ensure auth providers depend on user repository contracts.
- Keep router response shapes unchanged.
- Add auth/admin/router regression tests.

### Step 6: Clean Up Legacy Persistence

- Compare legacy persistence usage only after the app/gateway migration is complete.
- Delete unused legacy repository implementations only after all call sites have migrated.
- Keep short-lived compatibility shims only where necessary.
- Remove memory backend paths from storage-owned durable persistence.

## Testing Strategy

Unit tests should cover:

- config parsing
- persistence setup
- table creation
- repository CRUD/query behavior
- typed JSON metadata filtering
- dialect SQL compilation
- cron exclusion

E2E tests should cover:

- SQLite persistence setup
- PostgreSQL temporary database setup
- MySQL temporary database setup
- repository contract behavior on all supported SQL backends
- JSON/Unicode round trip
- rollback behavior
- persistence close/cleanup

If CI does not yet provide PostgreSQL/MySQL services, the E2E tests can be kept as local-only verification for now.
@@ -35,7 +35,7 @@ def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
     if app_config is None:
         try:
             app_config = get_app_config()
-        except FileNotFoundError:
+        except (FileNotFoundError, ValueError):
             return False
     return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
@@ -0,0 +1,35 @@
[project]
name = "deerflow-storage"
version = "0.1.0"
description = "DeerFlow storage framework"
requires-python = ">=3.12"
dependencies = [
    "dotenv>=0.9.9",
    "pydantic>=2.12.5",
    "pyyaml>=6.0.3",
    "sqlalchemy[asyncio]>=2.0,<3.0",
    "alembic>=1.13",
    "langgraph>=1.1.9",
]
[project.optional-dependencies]
postgres = [
    "asyncpg>=0.29",
    "langgraph-checkpoint-postgres>=3.0.5",
    "psycopg[binary]>=3.3.3",
    "psycopg-pool>=3.3.0",
]
mysql = [
    "aiomysql>=0.2",
    "langgraph-checkpoint-mysql>=3.0.0",
]
sqlite = [
    "aiosqlite>=0.22.1",
    "langgraph-checkpoint-sqlite>=3.0.3"
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
packages = ["store"]
@@ -0,0 +1,5 @@
from .enums import DataBaseType

__all__ = [
    "DataBaseType",
]
@@ -0,0 +1,41 @@
from enum import Enum
from enum import IntEnum as SourceIntEnum
from enum import StrEnum as SourceStrEnum
from typing import Any, TypeVar

T = TypeVar("T", bound=Enum)


class _EnumBase:
    """Base enum class with common utility methods."""

    @classmethod
    def get_member_keys(cls) -> list[str]:
        """Return a list of enum member names."""
        return list(cls.__members__.keys())

    @classmethod
    def get_member_values(cls) -> list:
        """Return a list of enum member values."""
        return [item.value for item in cls.__members__.values()]

    @classmethod
    def get_member_dict(cls) -> dict[str, Any]:
        """Return a dict mapping member names to values."""
        return {name: item.value for name, item in cls.__members__.items()}


class IntEnum(_EnumBase, SourceIntEnum):
    """Integer enum base class."""


class StrEnum(_EnumBase, SourceStrEnum):
    """String enum base class."""


class DataBaseType(StrEnum):
    """Database type."""

    sqlite = "sqlite"
    mysql = "mysql"
    postgresql = "postgresql"
||||||
@@ -0,0 +1,286 @@
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self

import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field

from store.config.storage_config import StorageConfig

load_dotenv()

logger = logging.getLogger(__name__)


def _default_config_candidates() -> tuple[Path, ...]:
    """Return deterministic config.yaml locations without relying on cwd."""
    backend_dir = Path(__file__).resolve().parents[4]
    repo_root = backend_dir.parent
    cwd = Path.cwd().resolve()
    candidates = (
        cwd / "config.yaml",
        backend_dir / "config.yaml",
        repo_root / "config.yaml",
    )
    return tuple(dict.fromkeys(candidates))


def _storage_from_database_config(config_data: dict[str, Any]) -> None:
    """Keep the existing public `database:` config compatible with storage."""
    if "storage" in config_data:
        return

    database = config_data.get("database")
    if not isinstance(database, dict):
        return

    backend = database.get("backend")
    if backend == "memory":
        raise ValueError("database.backend='memory' is not supported by storage; handle memory mode before loading storage config")

    storage: dict[str, Any] = {
        "driver": "postgres" if backend == "postgres" else backend,
        "sqlite_dir": database.get("sqlite_dir", ".deer-flow/data"),
        "echo_sql": database.get("echo_sql", False),
        "pool_size": database.get("pool_size", 5),
    }

    postgres_url = database.get("postgres_url")
    if backend == "postgres" and isinstance(postgres_url, str) and postgres_url:
        from sqlalchemy.engine.url import make_url

        parsed = make_url(postgres_url)
        storage["database_url"] = postgres_url
        storage.update(
            {
                "username": parsed.username or "",
                "password": parsed.password or "",
                "host": parsed.host or "localhost",
                "port": parsed.port or 5432,
                "db_name": parsed.database or "deerflow",
            }
        )

    config_data["storage"] = storage


class AppConfig(BaseModel):
    """DeerFlow application configuration."""

    timezone: str = Field(default="UTC", description="Timezone for scheduling and timestamps (e.g. 'UTC', 'America/New_York')")
    log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
    storage: StorageConfig = Field(default=StorageConfig())
    model_config = ConfigDict(extra="allow", frozen=False)

    @classmethod
    def resolve_config_path(cls, config_path: str | None = None) -> Path:
        """Resolve the config file path.

        Priority:
        1. If provided `config_path` argument, use it.
        2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
        3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
        """
        if config_path:
            path = Path(config_path)
            if not Path.exists(path):
                raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
            return path
        elif os.getenv("DEER_FLOW_CONFIG_PATH"):
            path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
            if not Path.exists(path):
                raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
            return path
        else:
            for path in _default_config_candidates():
                if path.exists():
                    return path
            raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")

    @classmethod
    def from_file(cls, config_path: str | None = None) -> Self:
        """Load and validate config from YAML. See `resolve_config_path` for path resolution."""
        resolved_path = cls.resolve_config_path(config_path)
        with open(resolved_path, encoding="utf-8") as f:
            config_data = yaml.safe_load(f) or {}

        cls._check_config_version(config_data, resolved_path)

        config_data = cls.resolve_env_variables(config_data)
        _storage_from_database_config(config_data)

        if os.getenv("TIMEZONE"):
            config_data["timezone"] = os.getenv("TIMEZONE")

        result = cls.model_validate(config_data)
        return result

    @classmethod
    def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
        """Check if the user's config.yaml is outdated compared to config.example.yaml.

        Emits a warning if the user's config_version is lower than the example's.
        Missing config_version is treated as version 0 (pre-versioning).
        """
        try:
            user_version = int(config_data.get("config_version", 0))
        except (TypeError, ValueError):
            user_version = 0

        # Find config.example.yaml by searching config.yaml's directory and its parents
        example_path = None
        search_dir = config_path.parent
        for _ in range(5):  # search up to 5 levels
            candidate = search_dir / "config.example.yaml"
            if candidate.exists():
                example_path = candidate
                break
            parent = search_dir.parent
            if parent == search_dir:
                break
            search_dir = parent
        if example_path is None:
            return

        try:
            with open(example_path, encoding="utf-8") as f:
                example_data = yaml.safe_load(f)
            raw = example_data.get("config_version", 0) if example_data else 0
            try:
                example_version = int(raw)
            except (TypeError, ValueError):
                example_version = 0
        except Exception:
            return

        if user_version < example_version:
            logger.warning(
                "Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to merge new fields into your config.",
                user_version,
                example_version,
            )

    @classmethod
    def resolve_env_variables(cls, config: Any) -> Any:
        """Recursively replace $VAR strings with their environment variable values (e.g. $OPENAI_API_KEY)."""
        if isinstance(config, str):
            if config.startswith("$"):
                env_value = os.getenv(config[1:])
                if env_value is None:
                    raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
                return env_value
            return config
        elif isinstance(config, dict):
            return {k: cls.resolve_env_variables(v) for k, v in config.items()}
        elif isinstance(config, list):
            return [cls.resolve_env_variables(item) for item in config]
        return config


_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())


def _get_config_mtime(config_path: Path) -> float | None:
    """Get the modification time of a config file if it exists."""
    try:
        return config_path.stat().st_mtime
    except OSError:
        return None


def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
    """Load config from disk and refresh cache metadata."""
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom

    resolved_path = AppConfig.resolve_config_path(config_path)
    _app_config = AppConfig.from_file(str(resolved_path))
    _app_config_path = resolved_path
    _app_config_mtime = _get_config_mtime(resolved_path)
    _app_config_is_custom = False
    return _app_config


def get_app_config() -> AppConfig:
    """Get the DeerFlow config instance.

    Returns a cached singleton instance and automatically reloads it when the
    underlying config file path or modification time changes. Use
    `reload_app_config()` to force a reload, or `reset_app_config()` to clear
    the cache.
    """
    global _app_config, _app_config_path, _app_config_mtime

    runtime_override = _current_app_config.get()
    if runtime_override is not None:
        return runtime_override

    if _app_config is not None and _app_config_is_custom:
        return _app_config

    resolved_path = AppConfig.resolve_config_path()
    current_mtime = _get_config_mtime(resolved_path)

    should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
    if should_reload:
        if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
            logger.info(
                "Config file has been modified (mtime: %s -> %s), reloading AppConfig",
                _app_config_mtime,
                current_mtime,
            )
        _load_and_cache_app_config(str(resolved_path))
    return _app_config


def reload_app_config(config_path: str | None = None) -> AppConfig:
    """Force reload from file and update the cache."""
    return _load_and_cache_app_config(config_path)


def reset_app_config() -> None:
    """Clear the cache so the next `get_app_config()` reloads from file."""
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
    _app_config = None
    _app_config_path = None
    _app_config_mtime = None
    _app_config_is_custom = False


def set_app_config(config: AppConfig) -> None:
    """Inject a config instance directly, bypassing file loading (for testing)."""
    global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
    _app_config = config
    _app_config_path = None
    _app_config_mtime = None
    _app_config_is_custom = True


def peek_current_app_config() -> AppConfig | None:
    """Return the runtime-scoped AppConfig override, if one is active."""
    return _current_app_config.get()


def push_current_app_config(config: AppConfig) -> None:
    """Push a runtime-scoped AppConfig override for the current execution context."""
    stack = _current_app_config_stack.get()
    _current_app_config_stack.set(stack + (_current_app_config.get(),))
    _current_app_config.set(config)


def pop_current_app_config() -> None:
    """Pop the latest runtime-scoped AppConfig override for the current execution context."""
    stack = _current_app_config_stack.get()
    if not stack:
        _current_app_config.set(None)
        return
    previous = stack[-1]
    _current_app_config_stack.set(stack[:-1])
    _current_app_config.set(previous)
||||||
@@ -0,0 +1,69 @@
"""Unified storage backend configuration for checkpointer and application data.

SQLite: checkpointer → {sqlite_dir}/checkpoints.db, app → {sqlite_dir}/deerflow.db
        (separate files to avoid write-lock contention)
Postgres: shared URL, independent connection pools per layer.

Sensitive values use $VAR syntax resolved by AppConfig.resolve_env_variables()
before this config is instantiated.
"""

from __future__ import annotations

import os
from typing import Literal

from pydantic import BaseModel, Field


def _strip_legacy_state_prefix(path: str) -> str:
    """Keep old .deer-flow/* config values compatible with Paths.base_dir."""
    prefix = ".deer-flow/"
    if path == ".deer-flow":
        return "."
    if path.startswith(prefix):
        return path[len(prefix) :]
    return path


class StorageConfig(BaseModel):
    driver: Literal["mysql", "sqlite", "postgres", "postgresql"] = Field(
        default="sqlite",
        description="Storage driver for both checkpointer and application data. 'sqlite' for single-node deployment (default),'postgres' for production multi-node deployment, 'mysql' for MySQL databases.",
    )
    sqlite_dir: str = Field(
        default=".deer-flow/data",
        description="Directory for SQLite .db files (sqlite driver only).",
    )
    username: str = Field(default="", description="db username ")
    password: str = Field(default="", description="db password. Use $VAR syntax in config.yaml to read from .env.")
    host: str = Field(default="localhost", description="db host.")
    port: int = Field(default=5432, description="db port.")
    db_name: str = Field(default="deerflow", description="db database name.")
    database_url: str = Field(default="", description="Complete SQLAlchemy database URL. Takes precedence for non-SQLite drivers.")
    sqlite_db_path: str = Field(default=".deer-flow/data", description="Directory for SQLite .db files (sqlite driver only).")
    echo_sql: bool = Field(default=False, description="Log all SQL statements (debug only).")
    pool_size: int = Field(default=5, description="Connection pool size per layer.")

    # -- Derived helpers (not user-configured) --

    @property
    def _resolved_sqlite_dir(self) -> str:
        """Resolve sqlite_dir to an absolute path under DeerFlow's base dir."""
        from pathlib import Path

        path = Path(self.sqlite_dir)
        if path.is_absolute():
            return str(path.resolve())

        try:
            from deerflow.config.paths import resolve_path

            return str(resolve_path(_strip_legacy_state_prefix(self.sqlite_dir)))
        except ImportError:
            return str(path.resolve())

    @property
    def sqlite_storage_path(self) -> str:
        """SQLite file path for storage-owned app data and checkpointer."""
        return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
||||||
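To make the path handling above concrete, here is a small sketch of how a legacy `.deer-flow/...` value might end up as the final SQLite file path. It assumes that `resolve_path` (from `deerflow.config.paths`, not shown in this diff) joins a relative path onto a base directory; the `base_dir` parameter and the `sqlite_storage_path` function signature are illustrative only.

```python
import os


def strip_legacy_state_prefix(path: str) -> str:
    """Same rule as the diff: drop a leading '.deer-flow/' segment."""
    prefix = ".deer-flow/"
    if path == ".deer-flow":
        return "."
    if path.startswith(prefix):
        return path[len(prefix):]
    return path


def sqlite_storage_path(base_dir: str, sqlite_dir: str) -> str:
    """Join an assumed base dir with the stripped sqlite_dir and the DB file name."""
    if os.path.isabs(sqlite_dir):
        resolved = sqlite_dir
    else:
        resolved = os.path.join(base_dir, strip_legacy_state_prefix(sqlite_dir))
    return os.path.normpath(os.path.join(resolved, "deerflow.db"))
```

Under that assumption, the default `.deer-flow/data` resolves to `<base_dir>/data/deerflow.db`, so an old config value does not produce a doubled `.deer-flow/.deer-flow` segment.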
@@ -0,0 +1,32 @@
from store.persistence.base_model import (
    Base,
    DataClassBase,
    DateTimeMixin,
    MappedBase,
    TimeZone,
    UniversalText,
    id_key,
)

from .factory import (
    create_persistence,
    create_persistence_from_database_config,
    create_persistence_from_storage_config,
    storage_config_from_database_config,
)
from .types import AppPersistence

__all__ = [
    "Base",
    "DataClassBase",
    "DateTimeMixin",
    "MappedBase",
    "TimeZone",
    "UniversalText",
    "id_key",
    "create_persistence",
    "create_persistence_from_database_config",
    "create_persistence_from_storage_config",
    "storage_config_from_database_config",
    "AppPersistence",
]
@@ -0,0 +1,111 @@
from datetime import datetime
from typing import Annotated

from sqlalchemy import BigInteger, DateTime, Integer, Text, TypeDecorator
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column

from store.utils import get_timezone


def current_time() -> datetime:
    return get_timezone().now()


id_key = Annotated[
    int,
    mapped_column(
        BigInteger().with_variant(Integer, "sqlite"),
        primary_key=True,
        unique=True,
        index=True,
        autoincrement=True,
        sort_order=-999,
        comment="Primary key ID",
    ),
]


class UniversalText(TypeDecorator[str]):
    """Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""

    impl = Text
    cache_ok = True

    def load_dialect_impl(self, dialect):  # noqa: ANN001
        if dialect.name == "mysql":
            return dialect.type_descriptor(LONGTEXT())
        return dialect.type_descriptor(Text())

    def process_bind_param(self, value: str | None, dialect) -> str | None:  # noqa: ANN001
        return value

    def process_result_value(self, value: str | None, dialect) -> str | None:  # noqa: ANN001
        return value


class TimeZone(TypeDecorator[datetime]):
    """Timezone-aware datetime type compatible with PostgreSQL and MySQL."""

    impl = DateTime(timezone=True)
    cache_ok = True

    @property
    def python_type(self) -> type[datetime]:
        return datetime

    def process_bind_param(self, value: datetime | None, dialect) -> datetime | None:  # noqa: ANN001
        timezone = get_timezone()
        if value is not None and value.utcoffset() != timezone.now().utcoffset():
            value = timezone.from_datetime(value)
        return value

    def process_result_value(self, value: datetime | None, dialect) -> datetime | None:  # noqa: ANN001
        timezone = get_timezone()
        if value is not None and value.tzinfo is None:
            value = value.replace(tzinfo=timezone.tz_info)
        return value


class DateTimeMixin(MappedAsDataclass):
    """Mixin that adds created_time / updated_time columns."""

    created_time: Mapped[datetime] = mapped_column(
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Created at",
    )
    updated_time: Mapped[datetime | None] = mapped_column(
        TimeZone,
        init=False,
        onupdate=current_time,
        sort_order=999,
        comment="Updated at",
    )


class MappedBase(AsyncAttrs, DeclarativeBase):
    """Async-capable declarative base for all ORM models."""

    @declared_attr.directive
    def __tablename__(self) -> str:
        return self.__name__.lower()

    @declared_attr.directive
    def __table_args__(self) -> dict:
        return {"comment": self.__doc__ or ""}


class DataClassBase(MappedAsDataclass, MappedBase):
    """Declarative base with native dataclass integration."""

    __abstract__ = True


class Base(DataClassBase, DateTimeMixin):
    """Declarative dataclass base with created_time / updated_time columns."""

    __abstract__ = True
||||||
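The `process_result_value` hook in `TimeZone` attaches the application timezone to naive datetimes coming back from the database. The sketch below reproduces only that rule with the standard library. `store.utils.get_timezone` is not shown in this diff, so a fixed UTC+8 offset (`APP_TZ`) stands in for it as an assumption.

```python
from datetime import datetime, timedelta, timezone

APP_TZ = timezone(timedelta(hours=8))  # assumed application timezone for the example


def normalize_result(value, tz=APP_TZ):
    """Attach tz to naive datetimes; leave aware datetimes and None untouched."""
    if value is not None and value.tzinfo is None:
        # replace() only labels the value; it does not shift the wall-clock time.
        value = value.replace(tzinfo=tz)
    return value


naive = datetime(2024, 1, 1, 12, 0)
aware = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)
labelled = normalize_result(naive)
```

Since `replace` does not convert, this rule is only correct when naive values were written in the application timezone, which is what the matching `process_bind_param` step appears intended to ensure.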
@@ -0,0 +1,9 @@
from .mysql import build_mysql_persistence
from .postgres import build_postgres_persistence
from .sqlite import build_sqlite_persistence

__all__ = [
    "build_postgres_persistence",
    "build_mysql_persistence",
    "build_sqlite_persistence",
]
@@ -0,0 +1,76 @@
from __future__ import annotations

import json

from sqlalchemy import URL
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence


def _validate_mysql_driver(db_url: URL) -> str:
    url = make_url(db_url)
    driver = url.get_driver_name()

    if driver not in {"aiomysql", "asyncmy"}:
        raise ValueError(f"MySQL persistence requires async SQLAlchemy driver (aiomysql/asyncmy), got: {driver!r}")
    return driver


def _checkpoint_conn_string(db_url: URL) -> str:
    return db_url.render_as_string(hide_password=False)


async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
    _validate_mysql_driver(db_url)

    from langgraph.checkpoint.mysql.aio import AIOMySQLSaver

    import store.repositories.models  # noqa: F401

    engine = create_async_engine(
        db_url,
        echo=echo,
        future=True,
        pool_pre_ping=True,
        pool_size=pool_size,
        json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
    )

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(db_url))
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables / migrations
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
@@ -0,0 +1,64 @@
from __future__ import annotations

import json

from sqlalchemy import URL
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence


def _checkpoint_conn_string(db_url: URL) -> str:
    return db_url.set(drivername="postgresql").render_as_string(hide_password=False)


async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
    from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver

    import store.repositories.models  # noqa: F401

    engine = create_async_engine(
        db_url,
        echo=echo,
        future=True,
        pool_pre_ping=True,
        pool_size=pool_size,
        json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
    )

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    saver_cm = AsyncPostgresSaver.from_conn_string(_checkpoint_conn_string(db_url))
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables / migrations
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
||||||
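`_checkpoint_conn_string` exists because the SQLAlchemy engine URL carries a `+asyncpg` driver suffix, while the checkpointer connects through psycopg and needs a plain `postgresql://` URI. The sketch below illustrates the same rewrite with `urllib.parse` instead of SQLAlchemy's `URL.set`, purely so it runs without third-party packages; the function name and the sample URL are illustrative.

```python
from urllib.parse import urlsplit


def checkpoint_conn_string(sqlalchemy_url: str) -> str:
    """Force the scheme to 'postgresql', keeping credentials, host, port and database."""
    parts = urlsplit(sqlalchemy_url)
    return parts._replace(scheme="postgresql").geturl()


engine_url = "postgresql+asyncpg://user:secret@db:5432/deerflow"
saver_url = checkpoint_conn_string(engine_url)
```

The password is kept in the result on purpose, mirroring `render_as_string(hide_password=False)` in the diff, so the returned string should not be logged.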
@@ -0,0 +1,68 @@
from __future__ import annotations

import json

from sqlalchemy import URL, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine

from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence


async def build_sqlite_persistence(db_url: URL, *, echo: bool = False) -> AppPersistence:
    from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver

    import store.repositories.models  # noqa: F401

    engine = create_async_engine(
        db_url,
        echo=echo,
        future=True,
        json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
    )

    @event.listens_for(engine.sync_engine, "connect")
    def _enable_sqlite_pragmas(dbapi_conn, _record):  # noqa: ANN001
        cursor = dbapi_conn.cursor()
        try:
            cursor.execute("PRAGMA journal_mode=WAL;")
            cursor.execute("PRAGMA synchronous=NORMAL;")
            cursor.execute("PRAGMA foreign_keys=ON;")
        finally:
            cursor.close()

    session_factory = async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
        autoflush=False,
    )

    saver_cm = AsyncSqliteSaver.from_conn_string(db_url.database)
    checkpointer = await saver_cm.__aenter__()

    async def setup() -> None:
        # 1. LangGraph checkpoint tables
        await checkpointer.setup()

        # 2. ORM business tables
        async with engine.begin() as conn:
            await conn.run_sync(MappedBase.metadata.create_all)

    async def _close_saver() -> None:
        await saver_cm.__aexit__(None, None, None)

    async def aclose() -> None:
        await close_in_order(
            engine.dispose,
            _close_saver,
        )

    return AppPersistence(
        checkpointer=checkpointer,
        engine=engine,
        session_factory=session_factory,
        setup=setup,
        aclose=aclose,
    )
||||||
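All three builders hand their teardown steps to `close_in_order` from `store.persistence.shared`, which is not shown in this part of the diff. The following is a guess at reasonable semantics, not the actual helper: it awaits each closer in the given order, keeps going if one fails, and re-raises the first error at the end. The `close_engine` and `close_saver` coroutines are stand-ins for `engine.dispose` and `_close_saver`.

```python
import asyncio


async def close_in_order(*closers):
    """Await each closer in order; keep going on failure and re-raise the first error."""
    first_error = None
    for closer in closers:
        try:
            await closer()
        except Exception as exc:  # keep closing the remaining resources
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error


calls = []


async def close_engine():
    calls.append("engine")
    raise RuntimeError("engine dispose failed")


async def close_saver():
    calls.append("saver")


try:
    asyncio.run(close_in_order(close_engine, close_saver))
except RuntimeError as exc:
    outcome = str(exc)
```

Continuing past a failed closer matters here because the checkpointer's connection would otherwise leak whenever disposing the engine raises.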
@@ -0,0 +1,123 @@
from typing import Any

from sqlalchemy import URL
from sqlalchemy.engine.url import make_url

from store.common import DataBaseType
from store.config.app_config import get_app_config
from store.config.storage_config import StorageConfig
from store.persistence.types import AppPersistence


def storage_config_from_database_config(database_config: Any) -> StorageConfig:
    """Convert the existing public DatabaseConfig shape to StorageConfig.

    Storage only owns durable database-backed persistence. The app bridge
    should handle memory mode before calling into this package.
    """
    backend = getattr(database_config, "backend", None)
    if backend == "sqlite":
        return StorageConfig(
            driver="sqlite",
            sqlite_dir=getattr(database_config, "sqlite_dir", ".deer-flow/data"),
            echo_sql=getattr(database_config, "echo_sql", False),
            pool_size=getattr(database_config, "pool_size", 5),
        )

    if backend == "postgres":
        postgres_url = getattr(database_config, "postgres_url", "")
        if not postgres_url:
            raise ValueError("database.postgres_url is required when database.backend is 'postgres'")
        parsed = make_url(postgres_url)
        return StorageConfig(
            driver="postgres",
            database_url=postgres_url,
            username=parsed.username or "",
            password=parsed.password or "",
            host=parsed.host or "localhost",
            port=parsed.port or 5432,
            db_name=parsed.database or "deerflow",
            echo_sql=getattr(database_config, "echo_sql", False),
            pool_size=getattr(database_config, "pool_size", 5),
        )

    raise ValueError(f"Unsupported database backend for storage persistence: {backend!r}")


def _create_database_url(storage_config: StorageConfig) -> URL:
    """Build an async SQLAlchemy URL from StorageConfig (sqlite/mysql/postgres)."""

    if storage_config.driver == DataBaseType.sqlite:
        driver = "sqlite+aiosqlite"
    elif storage_config.driver == DataBaseType.mysql:
        driver = "mysql+aiomysql"
    elif storage_config.driver in (DataBaseType.postgresql, "postgres"):
        driver = "postgresql+asyncpg"
    else:
        raise ValueError(f"Unsupported database driver: {storage_config.driver}")

    if storage_config.driver == DataBaseType.sqlite:
        import os

        db_path = storage_config.sqlite_storage_path
        os.makedirs(os.path.dirname(db_path), exist_ok=True)

        url = URL.create(
            drivername=driver,
            database=db_path,
        )
    elif storage_config.database_url:
        url = make_url(storage_config.database_url)
        if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
            url = url.set(drivername="postgresql+asyncpg")
        elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
            url = url.set(drivername="mysql+aiomysql")
    else:
        url = URL.create(
            drivername=driver,
            username=storage_config.username,
            password=storage_config.password,
            host=storage_config.host,
            port=storage_config.port,
            database=storage_config.db_name or "deerflow",
        )

    return url


async def create_persistence_from_storage_config(storage_config: StorageConfig) -> AppPersistence:
    from .drivers.mysql import build_mysql_persistence
    from .drivers.postgres import build_postgres_persistence
    from .drivers.sqlite import build_sqlite_persistence

    driver = storage_config.driver
    db_url = _create_database_url(storage_config)

    if driver in ("postgres", "postgresql"):
        return await build_postgres_persistence(
            db_url,
            echo=storage_config.echo_sql,
            pool_size=storage_config.pool_size,
        )

    if driver == "mysql":
        return await build_mysql_persistence(
            db_url,
            echo=storage_config.echo_sql,
            pool_size=storage_config.pool_size,
        )

    if driver == "sqlite":
        return await build_sqlite_persistence(db_url, echo=storage_config.echo_sql)

    raise ValueError(f"Unsupported database driver: {driver}")


async def create_persistence_from_database_config(database_config: Any) -> AppPersistence:
    storage_config = storage_config_from_database_config(database_config)
    return await create_persistence_from_storage_config(storage_config)


async def create_persistence() -> AppPersistence:
    app_config = get_app_config()
    return await create_persistence_from_storage_config(app_config.storage)
@@ -0,0 +1,189 @@
"""Dialect-aware JSON value matching for storage SQLAlchemy repositories."""

from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Any

from sqlalchemy import BigInteger, Float, String, bindparam
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.types import Boolean, TypeEngine

_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)

_INT64_MIN = -(2**63)
_INT64_MAX = 2**63 - 1


def validate_metadata_filter_key(key: object) -> bool:
    """Return True when *key* is safe for JSON metadata filter SQL paths."""
    return isinstance(key, str) and bool(_KEY_CHARSET_RE.fullmatch(key))


def validate_metadata_filter_value(value: object) -> bool:
    """Return True when *value* can be compiled into a portable JSON predicate."""
    if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
        return False
    if isinstance(value, int) and not isinstance(value, bool):
        return _INT64_MIN <= value <= _INT64_MAX
    return True


class JsonMatch(ColumnElement[bool]):
    """Dialect-portable ``column[key] == value`` for JSON columns."""

    inherit_cache = True
    type = Boolean()
    _is_implicitly_boolean = True

    _traverse_internals = [
        ("column", InternalTraversal.dp_clauseelement),
        ("key", InternalTraversal.dp_string),
        ("value", InternalTraversal.dp_plain_obj),
        ("value_type", InternalTraversal.dp_string),
    ]

    def __init__(self, column: ColumnElement[Any], key: str, value: object) -> None:
        if not validate_metadata_filter_key(key):
            raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")
        if not validate_metadata_filter_value(value):
            if isinstance(value, int) and not isinstance(value, bool):
                raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}")
            raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}")
        self.column = column
        self.key = key
        self.value = value
        self.value_type = type(value).__qualname__
        super().__init__()


@dataclass(frozen=True)
class _Dialect:
    null_type: str
    num_types: tuple[str, ...]
    num_cast: str
    int_types: tuple[str, ...]
    int_cast: str
    int_guard: str | None
    string_type: str
    bool_type: str | None
    true_value: str
    false_value: str


_SQLITE = _Dialect(
    null_type="null",
    num_types=("integer", "real"),
    num_cast="REAL",
    int_types=("integer",),
    int_cast="INTEGER",
    int_guard=None,
    string_type="text",
    bool_type=None,
    true_value="true",
    false_value="false",
)

_POSTGRES = _Dialect(
    null_type="null",
    num_types=("number",),
    num_cast="DOUBLE PRECISION",
    int_types=("number",),
    int_cast="BIGINT",
    int_guard="'^-?[0-9]+$'",
    string_type="string",
    bool_type="boolean",
    true_value="true",
    false_value="false",
)

_MYSQL = _Dialect(
    null_type="NULL",
    num_types=("INTEGER", "DOUBLE", "DECIMAL"),
    num_cast="DOUBLE",
    int_types=("INTEGER",),
    int_cast="SIGNED",
    int_guard=None,
    string_type="STRING",
    bool_type="BOOLEAN",
    true_value="true",
    false_value="false",
)


def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
    param = bindparam(None, value, type_=sa_type)
    return compiler.process(param, **kw)


def _type_check(typeof: str, types: tuple[str, ...]) -> str:
    if len(types) == 1:
        return f"{typeof} = '{types[0]}'"
    quoted = ", ".join(f"'{type_name}'" for type_name in types)
    return f"{typeof} IN ({quoted})"


def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
    if value is None:
        return f"{typeof} = '{dialect.null_type}'"
    if isinstance(value, bool):
        bool_str = dialect.true_value if value else dialect.false_value
        if dialect.bool_type is None:
            return f"{typeof} = '{bool_str}'"
        return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')"
    if isinstance(value, int):
        bp = _bind(compiler, value, BigInteger(), **kw)
        if dialect.int_guard:
            return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
        return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})"
    if isinstance(value, float):
        bp = _bind(compiler, value, Float(), **kw)
        return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})"
    bp = _bind(compiler, str(value), String(), **kw)
    return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})"


@compiles(JsonMatch, "sqlite")
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    if not validate_metadata_filter_key(element.key):
        raise ValueError(f"Key escaped validation: {element.key!r}")
    col = compiler.process(element.column, **kw)
    path = f'$."{element.key}"'
    typeof = f"json_type({col}, '{path}')"
    extract = f"json_extract({col}, '{path}')"
    return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw)


@compiles(JsonMatch, "postgresql")
def _compile_postgres(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    if not validate_metadata_filter_key(element.key):
        raise ValueError(f"Key escaped validation: {element.key!r}")
    col = compiler.process(element.column, **kw)
    typeof = f"json_typeof({col} -> '{element.key}')"
    extract = f"({col} ->> '{element.key}')"
    return _build_clause(compiler, typeof, extract, element.value, _POSTGRES, **kw)


@compiles(JsonMatch, "mysql")
def _compile_mysql(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    if not validate_metadata_filter_key(element.key):
        raise ValueError(f"Key escaped validation: {element.key!r}")
    col = compiler.process(element.column, **kw)
    path = f'$."{element.key}"'
    typeof = f"JSON_TYPE(JSON_EXTRACT({col}, '{path}'))"
    extract = f"JSON_UNQUOTE(JSON_EXTRACT({col}, '{path}'))"
    return _build_clause(compiler, typeof, extract, element.value, _MYSQL, **kw)


@compiles(JsonMatch)
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
    raise NotImplementedError(f"JsonMatch supports sqlite, postgresql, and mysql; got dialect: {compiler.dialect.name}")


def json_match(column: ColumnElement[Any], key: str, value: object) -> JsonMatch:
    return JsonMatch(column, key, value)
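Reviewer note: the key and value validators in the file above gate what may reach the JSON path SQL. The sketch below restates those two checks with the standard library only and shows one way a caller might pre-screen client-supplied filters. `split_filters` is a hypothetical helper, not part of this diff, and the "raise only when every filter is rejected" behavior is an assumption drawn from the `InvalidMetadataFilterError` docstring in the contracts package.

```python
import re

# Same character set as _KEY_CHARSET_RE; fullmatch() makes the anchors implicit.
_KEY_RE = re.compile(r"[A-Za-z0-9_\-]+")
_ALLOWED_TYPES = (type(None), bool, int, float, str)
_INT64_MIN, _INT64_MAX = -(2**63), 2**63 - 1


class InvalidMetadataFilterError(ValueError):
    """Stand-in for the contract exception of the same name."""


def is_valid_key(key: object) -> bool:
    return isinstance(key, str) and _KEY_RE.fullmatch(key) is not None


def is_valid_value(value: object) -> bool:
    if not isinstance(value, _ALLOWED_TYPES):
        return False
    # bool is a subclass of int, so exclude it before the range check.
    if isinstance(value, int) and not isinstance(value, bool):
        return _INT64_MIN <= value <= _INT64_MAX
    return True


def split_filters(filters: dict) -> tuple[dict, list]:
    """Hypothetical helper: keep usable filters and report the rejected keys."""
    accepted = {k: v for k, v in filters.items() if is_valid_key(k) and is_valid_value(v)}
    rejected = [k for k in filters if k not in accepted]
    if filters and not accepted:
        raise InvalidMetadataFilterError(f"all metadata filters rejected: {rejected!r}")
    return accepted, rejected
```

Under these assumptions, `split_filters({"env": "prod", "a.b": 1})` keeps `env` and reports `a.b`, while a dict containing only unusable entries raises.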
@@ -0,0 +1,3 @@
from .close import close_in_order

__all__ = ["close_in_order"]
@@ -0,0 +1,28 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable

AsyncCloser = Callable[[], Awaitable[None]]


async def close_in_order(*closers: AsyncCloser) -> None:
    """
    Run async closers in order and raise the first error, if any.

    Notes
    -----
    - Used to keep driver-specific close logic readable.
    - We intentionally do not stop at first failure, so later resources
      still get a chance to close.
    """
    first_error: Exception | None = None

    for closer in closers:
        try:
            await closer()
        except Exception as exc:
            if first_error is None:
                first_error = exc

    if first_error is not None:
        raise first_error
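Reviewer note: a small, self-contained check of the behavior the docstring above describes, namely that a failing closer does not prevent later closers from running and that the first error is the one re-raised. The function body is copied from the file above; `close_engine`, `close_saver`, and `demo` are illustrative names only.

```python
import asyncio
from collections.abc import Awaitable, Callable

AsyncCloser = Callable[[], Awaitable[None]]


async def close_in_order(*closers: AsyncCloser) -> None:
    first_error = None
    for closer in closers:
        try:
            await closer()
        except Exception as exc:
            # Remember only the first failure, but keep closing the rest.
            if first_error is None:
                first_error = exc
    if first_error is not None:
        raise first_error


def demo() -> tuple[list[str], str]:
    calls: list[str] = []

    async def close_engine() -> None:
        calls.append("engine")
        raise RuntimeError("engine dispose failed")

    async def close_saver() -> None:
        calls.append("saver")

    async def run() -> str:
        try:
            await close_in_order(close_engine, close_saver)
        except RuntimeError as exc:
            return str(exc)
        return "no error"

    return calls, asyncio.run(run())
```

Calling `demo()` records both closers even though the first one raises, which is the property the SQLite driver relies on when it passes `engine.dispose` and `_close_saver`.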
@@ -0,0 +1,23 @@
from __future__ import annotations

from collections.abc import Awaitable, Callable
from dataclasses import dataclass

from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker

AsyncSetup = Callable[[], Awaitable[None]]
AsyncClose = Callable[[], Awaitable[None]]


@dataclass(slots=True)
class AppPersistence:
    """
    Unified runtime persistence bundle.
    """

    checkpointer: Checkpointer
    engine: AsyncEngine
    session_factory: async_sessionmaker[AsyncSession]
    setup: AsyncSetup
    aclose: AsyncClose
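Reviewer note: the bundle above exposes `setup` and `aclose` as plain async callables, which suggests a bracket-style lifecycle on the app side. The sketch below illustrates that pattern with a stand-in object so it runs without a database. `FakePersistence`, `run_with_persistence`, and `demo` are hypothetical names, and how the app layer actually wires this up is not shown in this part of the diff.

```python
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import dataclass


@dataclass
class FakePersistence:
    """Stand-in exposing only the two lifecycle hooks of AppPersistence."""

    setup: Callable[[], Awaitable[None]]
    aclose: Callable[[], Awaitable[None]]


async def run_with_persistence(persistence: FakePersistence, work: Callable[[], Awaitable[None]]) -> None:
    # Create tables first, then always release resources, even if work() fails.
    await persistence.setup()
    try:
        await work()
    finally:
        await persistence.aclose()


def demo() -> list[str]:
    log: list[str] = []

    async def setup() -> None:
        log.append("setup")

    async def aclose() -> None:
        log.append("aclose")

    async def work() -> None:
        log.append("work")
        raise RuntimeError("request failed")

    async def main() -> None:
        try:
            await run_with_persistence(FakePersistence(setup, aclose), work)
        except RuntimeError:
            log.append("error propagated")

    asyncio.run(main())
    return log
```

The same shape would fit an application startup/shutdown hook, with `work` replaced by serving requests.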
@@ -0,0 +1,53 @@
from store.repositories.contracts import (
    Feedback,
    FeedbackAggregate,
    FeedbackCreate,
    FeedbackRepositoryProtocol,
    InvalidMetadataFilterError,
    Run,
    RunCreate,
    RunEvent,
    RunEventCreate,
    RunEventRepositoryProtocol,
    RunRepositoryProtocol,
    ThreadMeta,
    ThreadMetaCreate,
    ThreadMetaRepositoryProtocol,
    User,
    UserCreate,
    UserNotFoundError,
    UserRepositoryProtocol,
)
from store.repositories.factory import (
    build_feedback_repository,
    build_run_event_repository,
    build_run_repository,
    build_thread_meta_repository,
    build_user_repository,
)

__all__ = [
    "Feedback",
    "FeedbackAggregate",
    "FeedbackCreate",
    "FeedbackRepositoryProtocol",
    "InvalidMetadataFilterError",
    "Run",
    "RunCreate",
    "RunEvent",
    "RunEventCreate",
    "RunEventRepositoryProtocol",
    "RunRepositoryProtocol",
    "ThreadMeta",
    "ThreadMetaCreate",
    "ThreadMetaRepositoryProtocol",
    "User",
    "UserCreate",
    "UserNotFoundError",
    "UserRepositoryProtocol",
    "build_run_repository",
    "build_run_event_repository",
    "build_thread_meta_repository",
    "build_feedback_repository",
    "build_user_repository",
]
@@ -0,0 +1,49 @@
from store.repositories.contracts.feedback import (
    Feedback,
    FeedbackAggregate,
    FeedbackCreate,
    FeedbackRepositoryProtocol,
)
from store.repositories.contracts.run import (
    Run,
    RunCreate,
    RunRepositoryProtocol,
)
from store.repositories.contracts.run_event import (
    RunEvent,
    RunEventCreate,
    RunEventRepositoryProtocol,
)
from store.repositories.contracts.thread_meta import (
    InvalidMetadataFilterError,
    ThreadMeta,
    ThreadMetaCreate,
    ThreadMetaRepositoryProtocol,
)
from store.repositories.contracts.user import (
    User,
    UserCreate,
    UserNotFoundError,
    UserRepositoryProtocol,
)

__all__ = [
    "Feedback",
    "FeedbackAggregate",
    "FeedbackCreate",
    "FeedbackRepositoryProtocol",
    "Run",
    "RunCreate",
    "RunEvent",
    "RunEventCreate",
    "RunEventRepositoryProtocol",
    "RunRepositoryProtocol",
    "InvalidMetadataFilterError",
    "ThreadMeta",
    "ThreadMetaCreate",
    "ThreadMetaRepositoryProtocol",
    "User",
    "UserCreate",
    "UserNotFoundError",
    "UserRepositoryProtocol",
]
@@ -0,0 +1,77 @@
from __future__ import annotations

from datetime import datetime
from typing import Protocol, TypedDict

from pydantic import BaseModel, ConfigDict


class FeedbackCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    feedback_id: str
    run_id: str
    thread_id: str
    rating: int
    user_id: str | None = None
    message_id: str | None = None
    comment: str | None = None


class Feedback(BaseModel):
    model_config = ConfigDict(frozen=True)

    feedback_id: str
    run_id: str
    thread_id: str
    rating: int
    user_id: str | None
    message_id: str | None
    comment: str | None
    created_time: datetime


class FeedbackAggregate(TypedDict):
    run_id: str
    total: int
    positive: int
    negative: int


class FeedbackRepositoryProtocol(Protocol):
    async def create_feedback(self, data: FeedbackCreate) -> Feedback:
        pass

    async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
        pass

    async def get_feedback(self, feedback_id: str) -> Feedback | None:
        pass

    async def list_feedback_by_run(
        self,
        run_id: str,
        *,
        thread_id: str | None = None,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        pass

    async def list_feedback_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        pass

    async def delete_feedback(self, feedback_id: str) -> bool:
        pass

    async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
        pass

    async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
        pass
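Reviewer note: `FeedbackAggregate` only declares its shape, so the sketch below shows one plausible in-memory reading of it. It assumes `positive` counts ratings of `+1`, `negative` counts ratings of `-1`, and `total` counts all rows; the `+1`/`-1` constraint comes from `DbFeedbackRepository.create_feedback` later in this diff. `aggregate_ratings` is a hypothetical helper, and the SQL-backed implementation may compute these values differently.

```python
from typing import TypedDict


class FeedbackAggregate(TypedDict):
    run_id: str
    total: int
    positive: int
    negative: int


def aggregate_ratings(run_id: str, ratings: list[int]) -> FeedbackAggregate:
    """Hypothetical in-memory equivalent of aggregate_feedback_by_run."""
    for rating in ratings:
        if rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {rating}")
    positive = sum(1 for rating in ratings if rating == 1)
    negative = sum(1 for rating in ratings if rating == -1)
    return FeedbackAggregate(
        run_id=run_id,
        total=len(ratings),
        positive=positive,
        negative=negative,
    )
```

A memory-backed repository used in tests could return this structure directly, which keeps callers independent of the storage backend.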
@@ -0,0 +1,100 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Protocol

from pydantic import BaseModel, ConfigDict, Field


class RunCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    run_id: str
    thread_id: str
    assistant_id: str | None = None
    user_id: str | None = None
    status: str = "pending"
    model_name: str | None = None
    multitask_strategy: str = "reject"
    error: str | None = None
    follow_up_to_run_id: str | None = None
    metadata: dict[str, Any] = Field(default_factory=dict)
    kwargs: dict[str, Any] = Field(default_factory=dict)
    created_time: datetime | None = None


class Run(BaseModel):
    model_config = ConfigDict(frozen=True)

    run_id: str
    thread_id: str
    assistant_id: str | None
    user_id: str | None
    status: str
    model_name: str | None
    multitask_strategy: str
    error: str | None
    follow_up_to_run_id: str | None
    metadata: dict[str, Any]
    kwargs: dict[str, Any]
    total_input_tokens: int
    total_output_tokens: int
    total_tokens: int
    llm_call_count: int
    lead_agent_tokens: int
    subagent_tokens: int
    middleware_tokens: int
    message_count: int
    first_human_message: str | None
    last_ai_message: str | None
    created_time: datetime
    updated_time: datetime | None


class RunRepositoryProtocol(Protocol):
    async def create_run(self, data: RunCreate) -> Run:
        pass

    async def get_run(self, run_id: str) -> Run | None:
        pass

    async def list_runs_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[Run]:
        pass

    async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
        pass

    async def delete_run(self, run_id: str) -> None:
        pass

    async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
        pass

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        first_human_message: str | None = None,
        last_ai_message: str | None = None,
        error: str | None = None,
    ) -> None:
        pass

    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        pass
@@ -0,0 +1,83 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Protocol

from pydantic import BaseModel, ConfigDict, Field


class RunEventCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    thread_id: str
    run_id: str
    user_id: str | None = None
    event_type: str
    category: str
    content: Any = ""
    metadata: dict[str, Any] = Field(default_factory=dict)
    created_at: datetime | None = None


class RunEvent(BaseModel):
    model_config = ConfigDict(frozen=True)

    thread_id: str
    run_id: str
    user_id: str | None
    event_type: str
    category: str
    content: Any
    metadata: dict[str, Any]
    seq: int
    created_at: datetime


class RunEventRepositoryProtocol(Protocol):
    # Sequence values are time-ordered integer cursors. The application layer
    # owns the single-writer invariant for a thread while a run is active.
    async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
        pass

    async def list_messages(
        self,
        thread_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        pass

    async def list_events(
        self,
        thread_id: str,
        run_id: str,
        *,
        event_types: list[str] | None = None,
        limit: int = 500,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        pass

    async def list_messages_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        pass

    async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
        pass

    async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
        pass

    async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
        pass
@@ -0,0 +1,67 @@
from __future__ import annotations

from datetime import datetime
from typing import Any, Protocol

from pydantic import BaseModel, ConfigDict, Field


class InvalidMetadataFilterError(ValueError):
    """Raised when all client-supplied metadata filters are rejected."""


class ThreadMetaCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    thread_id: str
    assistant_id: str | None = None
    user_id: str | None = None
    display_name: str | None = None
    status: str = "idle"
    metadata: dict[str, Any] = Field(default_factory=dict)


class ThreadMeta(BaseModel):
    model_config = ConfigDict(frozen=True)

    thread_id: str
    assistant_id: str | None
    user_id: str | None
    display_name: str | None
    status: str
    metadata: dict[str, Any]
    created_time: datetime
    updated_time: datetime | None


class ThreadMetaRepositoryProtocol(Protocol):
    async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
        pass

    async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
        pass

    async def update_thread_meta(
        self,
        thread_id: str,
        *,
        display_name: str | None = None,
        status: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        pass

    async def delete_thread(self, thread_id: str) -> None:
        pass

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None = None,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        pass
@@ -0,0 +1,64 @@
from __future__ import annotations

from datetime import datetime
from typing import Literal, Protocol

from pydantic import BaseModel, ConfigDict


class UserNotFoundError(LookupError):
    """Raised when an update targets a user row that no longer exists."""


class UserCreate(BaseModel):
    model_config = ConfigDict(extra="forbid")

    id: str
    email: str
    password_hash: str | None = None
    system_role: Literal["admin", "user"] = "user"
    created_at: datetime | None = None
    oauth_provider: str | None = None
    oauth_id: str | None = None
    needs_setup: bool = False
    token_version: int = 0


class User(BaseModel):
    model_config = ConfigDict(frozen=True)

    id: str
    email: str
    password_hash: str | None
    system_role: Literal["admin", "user"]
    created_at: datetime
    oauth_provider: str | None
    oauth_id: str | None
    needs_setup: bool
    token_version: int


class UserRepositoryProtocol(Protocol):
    async def create_user(self, data: UserCreate) -> User:
        pass

    async def get_user_by_id(self, user_id: str) -> User | None:
        pass

    async def get_user_by_email(self, email: str) -> User | None:
        pass

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        pass

    async def get_first_admin(self) -> User | None:
        pass

    async def update_user(self, data: User) -> User:
        pass

    async def count_users(self) -> int:
        pass

    async def count_admin_users(self) -> int:
        pass
@@ -0,0 +1,13 @@
from store.repositories.db.feedback import DbFeedbackRepository
from store.repositories.db.run import DbRunRepository
from store.repositories.db.run_event import DbRunEventRepository
from store.repositories.db.thread_meta import DbThreadMetaRepository
from store.repositories.db.user import DbUserRepository

__all__ = [
    "DbFeedbackRepository",
    "DbRunRepository",
    "DbRunEventRepository",
    "DbThreadMetaRepository",
    "DbUserRepository",
]
@@ -0,0 +1,142 @@
from __future__ import annotations

from datetime import UTC, datetime

from sqlalchemy import case, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.feedback import Feedback, FeedbackAggregate, FeedbackCreate, FeedbackRepositoryProtocol
from store.repositories.models.feedback import Feedback as FeedbackModel


def _to_feedback(m: FeedbackModel) -> Feedback:
    return Feedback(
        feedback_id=m.feedback_id,
        run_id=m.run_id,
        thread_id=m.thread_id,
        rating=m.rating,
        user_id=m.user_id,
        message_id=m.message_id,
        comment=m.comment,
        created_time=m.created_time,
    )


class DbFeedbackRepository(FeedbackRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_feedback(self, data: FeedbackCreate) -> Feedback:
        if data.rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {data.rating}")
        model = FeedbackModel(
            feedback_id=data.feedback_id,
            run_id=data.run_id,
            thread_id=data.thread_id,
            rating=data.rating,
            user_id=data.user_id,
            message_id=data.message_id,
            comment=data.comment,
        )
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_feedback(model)

    async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
        if data.rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {data.rating}")

        result = await self._session.execute(
            select(FeedbackModel).where(
                FeedbackModel.thread_id == data.thread_id,
                FeedbackModel.run_id == data.run_id,
                FeedbackModel.user_id == data.user_id,
            )
        )
        model = result.scalar_one_or_none()
        if model is None:
            return await self.create_feedback(data)

        model.rating = data.rating
        model.message_id = data.message_id
        model.comment = data.comment
        model.created_time = datetime.now(UTC)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_feedback(model)

    async def get_feedback(self, feedback_id: str) -> Feedback | None:
        result = await self._session.execute(select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
        model = result.scalar_one_or_none()
        return _to_feedback(model) if model else None

    async def list_feedback_by_run(
        self,
        run_id: str,
        *,
        thread_id: str | None = None,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        stmt = select(FeedbackModel).where(FeedbackModel.run_id == run_id)
        if thread_id is not None:
            stmt = stmt.where(FeedbackModel.thread_id == thread_id)
        if user_id is not None:
            stmt = stmt.where(FeedbackModel.user_id == user_id)
        stmt = stmt.order_by(FeedbackModel.created_time.desc())
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self._session.execute(stmt)
        return [_to_feedback(m) for m in result.scalars().all()]

    async def list_feedback_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int | None = None,
    ) -> list[Feedback]:
        stmt = select(FeedbackModel).where(FeedbackModel.thread_id == thread_id)
        if user_id is not None:
            stmt = stmt.where(FeedbackModel.user_id == user_id)
        stmt = stmt.order_by(FeedbackModel.created_time.desc())
        if limit is not None:
            stmt = stmt.limit(limit)
        result = await self._session.execute(stmt)
        return [_to_feedback(m) for m in result.scalars().all()]

    async def delete_feedback(self, feedback_id: str) -> bool:
        existing = await self.get_feedback(feedback_id)
        if existing is None:
            return False
        await self._session.execute(delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
        return True

    async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
        stmt = select(FeedbackModel).where(
            FeedbackModel.thread_id == thread_id,
            FeedbackModel.run_id == run_id,
        )
        if user_id is not None:
            stmt = stmt.where(FeedbackModel.user_id == user_id)
        result = await self._session.execute(stmt)
        # Several users may have rated this run, so delete every match instead of expecting one.
        models = result.scalars().all()
        for model in models:
            await self._session.delete(model)
        return bool(models)

    async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
        stmt = select(
            func.count().label("total"),
            func.coalesce(func.sum(case((FeedbackModel.rating == 1, 1), else_=0)), 0).label("positive"),
            func.coalesce(func.sum(case((FeedbackModel.rating == -1, 1), else_=0)), 0).label("negative"),
        ).where(FeedbackModel.thread_id == thread_id, FeedbackModel.run_id == run_id)
        row = (await self._session.execute(stmt)).one()
        return {
            "run_id": run_id,
            "total": int(row.total),
            "positive": int(row.positive),
            "negative": int(row.negative),
        }
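The aggregate query in `aggregate_feedback_by_run` can be read as a plain counting rule. The following is a minimal pure-Python sketch of that rule, not the repository implementation; the function name `aggregate_ratings` and its list-of-ints input are assumptions made for illustration, while the rating check mirrors the `+1`/`-1` validation used by `create_feedback`.

```python
from __future__ import annotations

from collections import Counter


def aggregate_ratings(run_id: str, ratings: list[int]) -> dict[str, int | str]:
    """Hypothetical in-memory mirror of the SQL aggregate for one run."""
    for rating in ratings:
        # Same validation rule the repository applies before inserting a row.
        if rating not in (1, -1):
            raise ValueError(f"rating must be +1 or -1, got {rating}")
    counts = Counter(ratings)
    return {
        "run_id": run_id,
        "total": len(ratings),    # COUNT(*)
        "positive": counts[1],    # SUM(CASE WHEN rating = 1 THEN 1 ELSE 0 END)
        "negative": counts[-1],   # SUM(CASE WHEN rating = -1 THEN 1 ELSE 0 END)
    }


summary = aggregate_ratings("run-1", [1, 1, -1])
empty = aggregate_ratings("run-2", [])
```

An empty run yields zeros for all three counters, which is the role `COALESCE(..., 0)` plays in the SQL version, since `SUM` over no rows returns `NULL`.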
@@ -0,0 +1,185 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.run import Run, RunCreate, RunRepositoryProtocol
from store.repositories.models.run import Run as RunModel


def _to_run(m: RunModel) -> Run:
    return Run(
        run_id=m.run_id,
        thread_id=m.thread_id,
        assistant_id=m.assistant_id,
        user_id=m.user_id,
        status=m.status,
        model_name=m.model_name,
        multitask_strategy=m.multitask_strategy,
        error=m.error,
        follow_up_to_run_id=m.follow_up_to_run_id,
        metadata=dict(m.meta or {}),
        kwargs=dict(m.kwargs or {}),
        total_input_tokens=m.total_input_tokens,
        total_output_tokens=m.total_output_tokens,
        total_tokens=m.total_tokens,
        llm_call_count=m.llm_call_count,
        lead_agent_tokens=m.lead_agent_tokens,
        subagent_tokens=m.subagent_tokens,
        middleware_tokens=m.middleware_tokens,
        message_count=m.message_count,
        first_human_message=m.first_human_message,
        last_ai_message=m.last_ai_message,
        created_time=m.created_time,
        updated_time=m.updated_time,
    )


class DbRunRepository(RunRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_run(self, data: RunCreate) -> Run:
        model = RunModel(
            run_id=data.run_id,
            thread_id=data.thread_id,
            assistant_id=data.assistant_id,
            user_id=data.user_id,
            status=data.status,
            model_name=data.model_name,
            multitask_strategy=data.multitask_strategy,
            error=data.error,
            follow_up_to_run_id=data.follow_up_to_run_id,
            meta=dict(data.metadata),
            kwargs=dict(data.kwargs),
        )
        if data.created_time is not None:
            model.created_time = data.created_time
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_run(model)

    async def get_run(self, run_id: str) -> Run | None:
        result = await self._session.execute(select(RunModel).where(RunModel.run_id == run_id))
        model = result.scalar_one_or_none()
        return _to_run(model) if model else None

    async def list_runs_by_thread(
        self,
        thread_id: str,
        *,
        user_id: str | None = None,
        limit: int = 50,
        offset: int = 0,
    ) -> list[Run]:
        stmt = select(RunModel).where(RunModel.thread_id == thread_id)
        if user_id is not None:
            stmt = stmt.where(RunModel.user_id == user_id)
        stmt = stmt.order_by(RunModel.created_time.desc()).limit(limit).offset(offset)
        result = await self._session.execute(stmt)
        return [_to_run(m) for m in result.scalars().all()]

    async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
        values: dict = {"status": status}
        if error is not None:
            values["error"] = error
        await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))

    async def delete_run(self, run_id: str) -> None:
        await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))

    async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
        if before is None:
            before_dt = datetime.now().astimezone()
        elif isinstance(before, datetime):
            before_dt = before
        else:
            before_dt = datetime.fromisoformat(before)

        result = await self._session.execute(select(RunModel).where(RunModel.status == "pending", RunModel.created_time <= before_dt).order_by(RunModel.created_time.asc()))
        return [_to_run(m) for m in result.scalars().all()]

    async def update_run_completion(
        self,
        run_id: str,
        *,
        status: str,
        total_input_tokens: int = 0,
        total_output_tokens: int = 0,
        total_tokens: int = 0,
        llm_call_count: int = 0,
        lead_agent_tokens: int = 0,
        subagent_tokens: int = 0,
        middleware_tokens: int = 0,
        message_count: int = 0,
        first_human_message: str | None = None,
        last_ai_message: str | None = None,
        error: str | None = None,
    ) -> None:
        values = {
            "status": status,
            "total_input_tokens": total_input_tokens,
            "total_output_tokens": total_output_tokens,
            "total_tokens": total_tokens,
            "llm_call_count": llm_call_count,
            "lead_agent_tokens": lead_agent_tokens,
            "subagent_tokens": subagent_tokens,
            "middleware_tokens": middleware_tokens,
            "message_count": message_count,
        }
        if first_human_message is not None:
            values["first_human_message"] = first_human_message[:2000]
        if last_ai_message is not None:
            values["last_ai_message"] = last_ai_message[:2000]
        if error is not None:
            values["error"] = error
        await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))

    async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
        completed = RunModel.status.in_(("success", "error"))
        model_expr = func.coalesce(RunModel.model_name, "unknown")
        stmt = (
            select(
                model_expr.label("model"),
                func.count().label("runs"),
                func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
                func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
                func.coalesce(func.sum(RunModel.total_output_tokens), 0).label("total_output_tokens"),
                func.coalesce(func.sum(RunModel.lead_agent_tokens), 0).label("lead_agent"),
                func.coalesce(func.sum(RunModel.subagent_tokens), 0).label("subagent"),
                func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
            )
            .where(RunModel.thread_id == thread_id, completed)
            .group_by(model_expr)
        )

        rows = (await self._session.execute(stmt)).all()
        total_tokens = total_input = total_output = total_runs = 0
        lead_agent = subagent = middleware = 0
        by_model: dict[str, dict] = {}
        for row in rows:
            by_model[row.model] = {"tokens": row.total_tokens, "runs": row.runs}
            total_tokens += row.total_tokens
            total_input += row.total_input_tokens
            total_output += row.total_output_tokens
            total_runs += row.runs
            lead_agent += row.lead_agent
            subagent += row.subagent
            middleware += row.middleware

        return {
            "total_tokens": total_tokens,
            "total_input_tokens": total_input,
            "total_output_tokens": total_output,
            "total_runs": total_runs,
            "by_model": by_model,
            "by_caller": {
                "lead_agent": lead_agent,
                "subagent": subagent,
                "middleware": middleware,
            },
        }
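The grouping rule behind `aggregate_tokens_by_thread` may be easier to follow without the SQL. Below is a simplified pure-Python sketch under stated assumptions: runs are plain dicts, only `total_tokens` is rolled up, and the helper name `aggregate_tokens` is hypothetical. It is meant to illustrate the status filter and the `"unknown"` fallback, not to replace the query.

```python
from __future__ import annotations

from typing import Any


def aggregate_tokens(runs: list[dict[str, Any]]) -> dict[str, Any]:
    """Roll up token usage for finished runs, grouped by model name."""
    by_model: dict[str, dict[str, int]] = {}
    total_tokens = 0
    total_runs = 0
    for run in runs:
        # Only finished runs count, matching the status filter in the query.
        if run["status"] not in ("success", "error"):
            continue
        # COALESCE(model_name, 'unknown') replaces NULL only, not empty strings.
        model = run["model_name"] if run["model_name"] is not None else "unknown"
        bucket = by_model.setdefault(model, {"tokens": 0, "runs": 0})
        bucket["tokens"] += run["total_tokens"]
        bucket["runs"] += 1
        total_tokens += run["total_tokens"]
        total_runs += 1
    return {"total_tokens": total_tokens, "total_runs": total_runs, "by_model": by_model}


sample = [
    {"status": "success", "model_name": "m1", "total_tokens": 10},
    {"status": "error", "model_name": None, "total_tokens": 5},
    {"status": "pending", "model_name": "m1", "total_tokens": 99},
]
summary = aggregate_tokens(sample)
```

The pending run is excluded from every total, and the run without a model name is counted under `"unknown"`.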
@@ -0,0 +1,207 @@
from __future__ import annotations

import json
import secrets
import threading
import time
from typing import Any

from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
from store.repositories.models.run_event import RunEvent as RunEventModel

_SEQ_COUNTER_BITS = 12
_SEQ_PROCESS_BITS = 9
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS


class _SequenceAllocator:
    def __init__(self) -> None:
        self._last_millis = 0
        self._lock = threading.Lock()

    def allocate_base(self, batch_size: int) -> int:
        if batch_size >= _SEQ_COUNTER_LIMIT:
            raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")

        now_ms = time.time_ns() // 1_000_000
        with self._lock:
            seq_ms = max(now_ms, self._last_millis + 1)
            self._last_millis = seq_ms
        return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)


_sequence_allocator = _SequenceAllocator()


def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
    if not isinstance(content, str):
        next_metadata = {**metadata, "content_is_json": True}
        if isinstance(content, dict):
            next_metadata["content_is_dict"] = True
        return json.dumps(content, default=str, ensure_ascii=False), next_metadata
    return content, metadata


def _deserialize_content(content: str, metadata: dict[str, Any]) -> Any:
    if not (metadata.get("content_is_json") or metadata.get("content_is_dict")):
        return content
    try:
        return json.loads(content)
    except json.JSONDecodeError:
        return content


def _to_run_event(model: RunEventModel) -> RunEvent:
    raw_metadata = dict(model.meta or {})
    metadata = {key: value for key, value in raw_metadata.items() if key != "content_is_dict"}
    return RunEvent(
        thread_id=model.thread_id,
        run_id=model.run_id,
        user_id=model.user_id,
        event_type=model.event_type,
        category=model.category,
        content=_deserialize_content(model.content, raw_metadata),
        metadata=metadata,
        seq=model.seq,
        created_at=model.created_at,
    )


class DbRunEventRepository(RunEventRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
        if not events:
            return []

        seq_base = _sequence_allocator.allocate_base(len(events))

        rows: list[RunEventModel] = []

        for index, event in enumerate(events, start=1):
            content, metadata = _serialize_content(event.content, dict(event.metadata))
            row = RunEventModel(
                thread_id=event.thread_id,
                run_id=event.run_id,
                user_id=event.user_id,
                seq=seq_base + index,
                event_type=event.event_type,
                category=event.category,
                content=content,
                meta=metadata,
            )
            if event.created_at is not None:
                row.created_at = event.created_at
            self._session.add(row)
            rows.append(row)

        await self._session.flush()
        return [_to_run_event(row) for row in rows]

    async def list_messages(
        self,
        thread_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        stmt = select(RunEventModel).where(
            RunEventModel.thread_id == thread_id,
            RunEventModel.category == "message",
        )
        if user_id is not None:
            stmt = stmt.where(RunEventModel.user_id == user_id)
        if before_seq is not None:
            stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
            result = await self._session.execute(stmt)
            return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
        if after_seq is not None:
            stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
            result = await self._session.execute(stmt)
            return [_to_run_event(row) for row in result.scalars().all()]

        stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
        result = await self._session.execute(stmt)
        return list(reversed([_to_run_event(row) for row in result.scalars().all()]))

    async def list_events(
        self,
        thread_id: str,
        run_id: str,
        *,
        event_types: list[str] | None = None,
        limit: int = 500,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        stmt = select(RunEventModel).where(
            RunEventModel.thread_id == thread_id,
            RunEventModel.run_id == run_id,
        )
        if user_id is not None:
            stmt = stmt.where(RunEventModel.user_id == user_id)
        if event_types is not None:
            stmt = stmt.where(RunEventModel.event_type.in_(event_types))
        stmt = stmt.order_by(RunEventModel.seq.asc()).limit(limit)
        result = await self._session.execute(stmt)
        return [_to_run_event(row) for row in result.scalars().all()]

    async def list_messages_by_run(
        self,
        thread_id: str,
        run_id: str,
        *,
        limit: int = 50,
        before_seq: int | None = None,
        after_seq: int | None = None,
        user_id: str | None = None,
    ) -> list[RunEvent]:
        stmt = select(RunEventModel).where(
            RunEventModel.thread_id == thread_id,
            RunEventModel.run_id == run_id,
            RunEventModel.category == "message",
        )
        if user_id is not None:
            stmt = stmt.where(RunEventModel.user_id == user_id)
        if before_seq is not None:
            stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
            result = await self._session.execute(stmt)
            return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
        if after_seq is not None:
            stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
            result = await self._session.execute(stmt)
            return [_to_run_event(row) for row in result.scalars().all()]

        stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
        result = await self._session.execute(stmt)
        return list(reversed([_to_run_event(row) for row in result.scalars().all()]))

    async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
        stmt = select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
        if user_id is not None:
            stmt = stmt.where(RunEventModel.user_id == user_id)
        count = await self._session.scalar(stmt)
        return int(count or 0)

    async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
        conditions = [RunEventModel.thread_id == thread_id]
        if user_id is not None:
            conditions.append(RunEventModel.user_id == user_id)
        count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
        await self._session.execute(delete(RunEventModel).where(*conditions))
        return int(count or 0)

    async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
        conditions = [RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id]
        if user_id is not None:
            conditions.append(RunEventModel.user_id == user_id)
        count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
        await self._session.execute(delete(RunEventModel).where(*conditions))
        return int(count or 0)
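The `seq` values produced by `_SequenceAllocator` and `append_batch` pack three fields into one integer: a millisecond timestamp in the high bits, a 9-bit per-process salt, and a 12-bit per-batch counter. The sketch below illustrates that layout with standalone helpers; `compose_seq` and `decompose_seq` are hypothetical names that do not exist in the package, and the constants simply mirror `_SEQ_COUNTER_BITS` and `_SEQ_PROCESS_BITS`.

```python
from __future__ import annotations

COUNTER_BITS = 12  # mirrors _SEQ_COUNTER_BITS
PROCESS_BITS = 9   # mirrors _SEQ_PROCESS_BITS
TIMESTAMP_SHIFT = COUNTER_BITS + PROCESS_BITS


def compose_seq(millis: int, salt: int, index: int) -> int:
    """Build a seq the way append_batch does: base from the allocator plus a 1-based index."""
    base = (millis << TIMESTAMP_SHIFT) | (salt << COUNTER_BITS)
    return base + index


def decompose_seq(seq: int) -> tuple[int, int, int]:
    """Split a seq back into (millis, salt, index)."""
    index = seq & ((1 << COUNTER_BITS) - 1)
    salt = (seq >> COUNTER_BITS) & ((1 << PROCESS_BITS) - 1)
    millis = seq >> TIMESTAMP_SHIFT
    return millis, salt, index


seq = compose_seq(1_700_000_000_000, 5, 3)
parts = decompose_seq(seq)
```

Because the timestamp occupies the most significant bits, any seq from a later millisecond sorts after every seq from an earlier one, whatever the salt or counter. A present-day millisecond timestamp needs about 41 bits, so the composed value stays below 2^63 and would fit a signed 64-bit column, although the column type is not visible in this diff.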
@@ -0,0 +1,113 @@
from __future__ import annotations

import logging
from typing import Any

from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession

from store.persistence.json_compat import json_match
from store.repositories.contracts.thread_meta import (
    InvalidMetadataFilterError,
    ThreadMeta,
    ThreadMetaCreate,
    ThreadMetaRepositoryProtocol,
)
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel

logger = logging.getLogger(__name__)


def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
    return ThreadMeta(
        thread_id=m.thread_id,
        assistant_id=m.assistant_id,
        user_id=m.user_id,
        display_name=m.display_name,
        status=m.status,
        metadata=dict(m.meta or {}),
        created_time=m.created_time,
        updated_time=m.updated_time,
    )


class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
        model = ThreadMetaModel(
            thread_id=data.thread_id,
            assistant_id=data.assistant_id,
            user_id=data.user_id,
            display_name=data.display_name,
            status=data.status,
            meta=dict(data.metadata),
        )
        self._session.add(model)
        await self._session.flush()
        await self._session.refresh(model)
        return _to_thread_meta(model)

    async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
        result = await self._session.execute(select(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
        model = result.scalar_one_or_none()
        return _to_thread_meta(model) if model else None

    async def update_thread_meta(
        self,
        thread_id: str,
        *,
        display_name: str | None = None,
        status: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        values: dict = {}
        if display_name is not None:
            values["display_name"] = display_name
        if status is not None:
            values["status"] = status
        if metadata is not None:
            values["meta"] = dict(metadata)
        if not values:
            return
        await self._session.execute(update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))

    async def delete_thread(self, thread_id: str) -> None:
        await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))

    async def search_threads(
        self,
        *,
        metadata: dict[str, Any] | None = None,
        status: str | None = None,
        user_id: str | None = None,
        assistant_id: str | None = None,
        limit: int = 100,
        offset: int = 0,
    ) -> list[ThreadMeta]:
        stmt = select(ThreadMetaModel)

        if status is not None:
            stmt = stmt.where(ThreadMetaModel.status == status)
        if user_id is not None:
            stmt = stmt.where(ThreadMetaModel.user_id == user_id)
        if assistant_id is not None:
            stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
        if metadata:
            applied = 0
            for key, value in metadata.items():
                try:
                    stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
                    applied += 1
                except (ValueError, TypeError) as exc:
                    logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
            if applied == 0:
                rejected_keys = ", ".join(sorted(str(key) for key in metadata))
                raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")

        stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
        stmt = stmt.limit(limit).offset(offset)

        result = await self._session.execute(stmt)
        return [_to_thread_meta(m) for m in result.scalars().all()]
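The metadata handling in `search_threads` follows a "skip what is unsafe, fail only if nothing is left" policy. The sketch below reproduces that control flow in isolation. It rests on assumptions: the validation rule in `_SAFE_KEY` is invented for illustration because the real checks live in `json_match`, which is not shown in this diff, and the exception class here is a local stand-in for the contract exception of the same name.

```python
from __future__ import annotations

import re
from typing import Any


class InvalidMetadataFilterError(ValueError):
    """Local stand-in for the contract exception of the same name."""


# Hypothetical rule; the actual key validation performed by json_match is not shown.
_SAFE_KEY = re.compile(r"^[A-Za-z0-9_]+$")


def select_filters(metadata: dict[str, Any]) -> dict[str, Any]:
    """Keep the filters that pass validation; raise only when every key is rejected."""
    applied: dict[str, Any] = {}
    for key, value in metadata.items():
        if isinstance(key, str) and _SAFE_KEY.match(key):
            applied[key] = value
        # Rejected keys are skipped silently here; the repository logs a warning instead.
    if metadata and not applied:
        rejected = ", ".join(sorted(str(key) for key in metadata))
        raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected}")
    return applied


kept = select_filters({"team": "a", "bad key;": 1})
```

Raising when every key is rejected avoids silently turning a filtered search into an unfiltered one, which appears to be the intent of the `applied == 0` check.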
@@ -0,0 +1,98 @@
from __future__ import annotations

from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories.contracts.user import User, UserCreate, UserNotFoundError, UserRepositoryProtocol
from store.repositories.models.user import User as UserModel


def _to_user(model: UserModel) -> User:
    return User(
        id=model.id,
        email=model.email,
        password_hash=model.password_hash,
        system_role=model.system_role,  # type: ignore[arg-type]
        created_at=model.created_at,
        oauth_provider=model.oauth_provider,
        oauth_id=model.oauth_id,
        needs_setup=model.needs_setup,
        token_version=model.token_version,
    )


class DbUserRepository(UserRepositoryProtocol):
    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def create_user(self, data: UserCreate) -> User:
        model = UserModel(
            id=data.id,
            email=data.email,
            system_role=data.system_role,
            password_hash=data.password_hash,
            oauth_provider=data.oauth_provider,
            oauth_id=data.oauth_id,
            needs_setup=data.needs_setup,
            token_version=data.token_version,
        )
        if data.created_at is not None:
            model.created_at = data.created_at
        self._session.add(model)
        try:
            await self._session.flush()
        except IntegrityError as exc:
            await self._session.rollback()
            raise ValueError(f"Email already registered: {data.email}") from exc
        await self._session.refresh(model)
        return _to_user(model)

    async def get_user_by_id(self, user_id: str) -> User | None:
        model = await self._session.get(UserModel, user_id)
        return _to_user(model) if model is not None else None

    async def get_user_by_email(self, email: str) -> User | None:
        result = await self._session.execute(select(UserModel).where(UserModel.email == email))
        model = result.scalar_one_or_none()
        return _to_user(model) if model is not None else None

    async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
        result = await self._session.execute(
            select(UserModel).where(
                UserModel.oauth_provider == provider,
                UserModel.oauth_id == oauth_id,
            )
        )
        model = result.scalar_one_or_none()
        return _to_user(model) if model is not None else None

    async def get_first_admin(self) -> User | None:
        result = await self._session.execute(select(UserModel).where(UserModel.system_role == "admin").limit(1))
        model = result.scalar_one_or_none()
        return _to_user(model) if model is not None else None

    async def update_user(self, data: User) -> User:
        model = await self._session.get(UserModel, data.id)
        if model is None:
            raise UserNotFoundError(f"User {data.id} no longer exists")

        model.email = data.email
        model.password_hash = data.password_hash
        model.system_role = data.system_role
        model.oauth_provider = data.oauth_provider
        model.oauth_id = data.oauth_id
        model.needs_setup = data.needs_setup
        model.token_version = data.token_version

        await self._session.flush()
        await self._session.refresh(model)
        return _to_user(model)

    async def count_users(self) -> int:
        count = await self._session.scalar(select(func.count()).select_from(UserModel))
        return int(count or 0)

    async def count_admin_users(self) -> int:
        count = await self._session.scalar(select(func.count()).select_from(UserModel).where(UserModel.system_role == "admin"))
        return int(count or 0)
@@ -0,0 +1,36 @@
from sqlalchemy.ext.asyncio import AsyncSession

from store.repositories import (
    FeedbackRepositoryProtocol,
    RunEventRepositoryProtocol,
    RunRepositoryProtocol,
    ThreadMetaRepositoryProtocol,
    UserRepositoryProtocol,
)
from store.repositories.db import (
    DbFeedbackRepository,
    DbRunEventRepository,
    DbRunRepository,
    DbThreadMetaRepository,
    DbUserRepository,
)


def build_thread_meta_repository(session: AsyncSession) -> ThreadMetaRepositoryProtocol:
    return DbThreadMetaRepository(session)


def build_run_repository(session: AsyncSession) -> RunRepositoryProtocol:
    return DbRunRepository(session)


def build_feedback_repository(session: AsyncSession) -> FeedbackRepositoryProtocol:
    return DbFeedbackRepository(session)


def build_run_event_repository(session: AsyncSession) -> RunEventRepositoryProtocol:
    return DbRunEventRepository(session)


def build_user_repository(session: AsyncSession) -> UserRepositoryProtocol:
    return DbUserRepository(session)
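The builders in this file are annotated to return protocol types rather than the concrete `Db*Repository` classes, so callers can be written against the contract alone. The following is a minimal, stdlib-only sketch of that idea. `UserRecord`, `UserReader`, `InMemoryUserReader`, `build_user_reader`, and `has_admin` are hypothetical names used for illustration and are not part of this change set; the real `UserRepositoryProtocol` presumably exposes more methods.

```python
import asyncio
from dataclasses import dataclass
from typing import Protocol


@dataclass(frozen=True)
class UserRecord:
    # Hypothetical stand-in for the storage package's user DTO.
    id: str
    system_role: str = "user"


class UserReader(Protocol):
    # Hypothetical, narrowed contract used only for this illustration.
    async def count_admin_users(self) -> int: ...


class InMemoryUserReader:
    """Test double that satisfies UserReader structurally (no inheritance)."""

    def __init__(self, users: list[UserRecord]) -> None:
        self._users = users

    async def count_admin_users(self) -> int:
        return sum(1 for user in self._users if user.system_role == "admin")


def build_user_reader(users: list[UserRecord]) -> UserReader:
    # Same shape as the builders above: the caller only sees the protocol type.
    return InMemoryUserReader(users)


async def has_admin(reader: UserReader) -> bool:
    # Application code depends on the contract, not on a DB-backed class.
    return await reader.count_admin_users() > 0


reader = build_user_reader([UserRecord("u1"), UserRecord("u2", "admin")])
admin_present = asyncio.run(has_admin(reader))
```

Because `has_admin` only needs the protocol, a DB-backed repository and an in-memory double can be swapped without touching the calling code.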
@@ -0,0 +1,7 @@
from store.repositories.models.feedback import Feedback
from store.repositories.models.run import Run
from store.repositories.models.run_event import RunEvent
from store.repositories.models.thread_meta import ThreadMeta
from store.repositories.models.user import User

__all__ = ["Feedback", "Run", "RunEvent", "ThreadMeta", "User"]
@@ -0,0 +1,36 @@
from __future__ import annotations

from datetime import datetime

from sqlalchemy import Integer, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time


class Feedback(DataClassBase):
    """Feedback table (create-only, no updated_time)."""

    __tablename__ = "feedback"
    __table_args__ = (
        UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
        {"comment": "Feedback table."},
    )

    feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    run_id: Mapped[str] = mapped_column(String(64), index=True)
    thread_id: Mapped[str] = mapped_column(String(64), index=True)
    rating: Mapped[int] = mapped_column(Integer)

    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    message_id: Mapped[str | None] = mapped_column(String(64), default=None)
    comment: Mapped[str | None] = mapped_column(UniversalText, default=None)

    created_time: Mapped[datetime] = mapped_column(
        "created_at",
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Created at",
    )
@@ -0,0 +1,63 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

from sqlalchemy import JSON, Index, Integer, String
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time


class Run(DataClassBase):
    """Run metadata table."""

    __tablename__ = "runs"
    __table_args__ = (
        Index("ix_runs_thread_status", "thread_id", "status"),
        {"comment": "Run metadata table."},
    )

    run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    thread_id: Mapped[str] = mapped_column(String(64), index=True)

    assistant_id: Mapped[str | None] = mapped_column(String(128), default=None)
    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    status: Mapped[str] = mapped_column(String(20), default="pending", index=True)
    model_name: Mapped[str | None] = mapped_column(String(128), default=None)
    multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
    error: Mapped[str | None] = mapped_column(UniversalText, default=None)
    follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64), default=None)

    meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
    kwargs: Mapped[dict[str, Any]] = mapped_column("kwargs_json", JSON, default_factory=dict)

    total_input_tokens: Mapped[int] = mapped_column(Integer, default=0)
    total_output_tokens: Mapped[int] = mapped_column(Integer, default=0)
    total_tokens: Mapped[int] = mapped_column(Integer, default=0)
    llm_call_count: Mapped[int] = mapped_column(Integer, default=0)
    lead_agent_tokens: Mapped[int] = mapped_column(Integer, default=0)
    subagent_tokens: Mapped[int] = mapped_column(Integer, default=0)
    middleware_tokens: Mapped[int] = mapped_column(Integer, default=0)

    message_count: Mapped[int] = mapped_column(Integer, default=0)
    first_human_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
    last_ai_message: Mapped[str | None] = mapped_column(UniversalText, default=None)

    created_time: Mapped[datetime] = mapped_column(
        "created_at",
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Created at",
    )
    updated_time: Mapped[datetime | None] = mapped_column(
        "updated_at",
        TimeZone,
        init=False,
        default=None,
        onupdate=current_time,
        sort_order=999,
        comment="Updated at",
    )
@@ -0,0 +1,46 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import (
    DataClassBase,
    TimeZone,
    UniversalText,
    current_time,
    id_key,
)


class RunEvent(DataClassBase):
    """Run event table."""

    __tablename__ = "run_events"
    __table_args__ = (
        UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
        Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
        Index("ix_events_run", "thread_id", "run_id", "seq"),
        {"comment": "Run event table."},
    )

    id: Mapped[id_key] = mapped_column(init=False)

    thread_id: Mapped[str] = mapped_column(String(64), index=True)
    run_id: Mapped[str] = mapped_column(String(64), index=True)
    event_type: Mapped[str] = mapped_column(String(32), index=True)
    category: Mapped[str] = mapped_column(String(16), index=True)

    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
    content: Mapped[str] = mapped_column(UniversalText, default="")
    meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
    created_at: Mapped[datetime] = mapped_column(
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Event timestamp",
    )
@@ -0,0 +1,43 @@
from __future__ import annotations

from datetime import datetime
from typing import Any

from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import DataClassBase, TimeZone, current_time


class ThreadMeta(DataClassBase):
    """Thread metadata table."""

    __tablename__ = "threads_meta"
    __table_args__ = {"comment": "Thread metadata table."}

    thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)

    assistant_id: Mapped[str | None] = mapped_column(String(128), default=None, index=True)
    user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
    display_name: Mapped[str | None] = mapped_column(String(256), default=None)
    status: Mapped[str] = mapped_column(String(20), default="idle", index=True)

    meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)

    created_time: Mapped[datetime] = mapped_column(
        "created_at",
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Created at",
    )
    updated_time: Mapped[datetime | None] = mapped_column(
        "updated_at",
        TimeZone,
        init=False,
        default=None,
        onupdate=current_time,
        sort_order=999,
        comment="Updated at",
    )
@@ -0,0 +1,42 @@
from __future__ import annotations

from datetime import datetime

from sqlalchemy import Boolean, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column

from store.persistence.base_model import DataClassBase, TimeZone, current_time


class User(DataClassBase):
    """User account table."""

    __tablename__ = "users"
    __table_args__ = (
        Index(
            "idx_users_oauth_identity",
            "oauth_provider",
            "oauth_id",
            unique=True,
            sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
        ),
        {"comment": "User account table."},
    )

    id: Mapped[str] = mapped_column(String(36), primary_key=True)
    email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
    system_role: Mapped[str] = mapped_column(String(16), default="user")

    password_hash: Mapped[str | None] = mapped_column(String(128), default=None)
    oauth_provider: Mapped[str | None] = mapped_column(String(32), default=None)
    oauth_id: Mapped[str | None] = mapped_column(String(128), default=None)
    needs_setup: Mapped[bool] = mapped_column(Boolean, default=False)
    token_version: Mapped[int] = mapped_column(default=0)

    created_at: Mapped[datetime] = mapped_column(
        TimeZone,
        init=False,
        default_factory=current_time,
        sort_order=999,
        comment="Created at",
    )
@@ -0,0 +1,3 @@
from .timezone import get_timezone

__all__ = ["get_timezone"]
@@ -0,0 +1,51 @@
import zoneinfo
from datetime import UTC, datetime

from store.config.app_config import get_app_config

# IANA identifiers that map to UTC; see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones
_UTC_IDENTIFIERS = frozenset({"Etc/UCT", "Etc/Universal", "Etc/UTC", "Etc/Zulu", "UCT", "Universal", "UTC", "Zulu"})


class TimeZone:
    def __init__(self) -> None:
        app_config = get_app_config()
        if app_config.timezone in _UTC_IDENTIFIERS:
            self.tz_info = UTC
        else:
            self.tz_info = zoneinfo.ZoneInfo(app_config.timezone)

    def now(self) -> datetime:
        """Return the current time in the configured timezone."""
        return datetime.now(self.tz_info)

    def from_datetime(self, t: datetime) -> datetime:
        """Convert a datetime to the configured timezone."""
        return t.astimezone(self.tz_info)

    def from_str(self, t_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> datetime:
        """Parse a time string and attach the configured timezone."""
        return datetime.strptime(t_str, format_str).replace(tzinfo=self.tz_info)

    @staticmethod
    def to_str(t: datetime, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
        """Format a datetime to string."""
        return t.strftime(format_str)

    @staticmethod
    def to_utc(t: datetime | int) -> datetime:
        """Convert a datetime or Unix timestamp to UTC."""
        if isinstance(t, datetime):
            return t.astimezone(UTC)
        return datetime.fromtimestamp(t, tz=UTC)


_timezone = None


def get_timezone() -> TimeZone:
    """Return the global TimeZone singleton (lazy-initialized)."""
    global _timezone
    if _timezone is None:
        _timezone = TimeZone()
    return _timezone
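`from_str` attaches the configured zone with `replace(tzinfo=...)`, while `from_datetime` and `to_utc` convert with `astimezone(...)`. The two operations behave differently, which the stdlib-only sketch below tries to make concrete. The fixed-offset `UTC_PLUS_8` zone is a hypothetical stand-in for whatever `app_config.timezone` resolves to, and the sample dates are arbitrary.

```python
from datetime import datetime, timedelta, timezone

# Hypothetical fixed-offset zone standing in for the configured timezone.
UTC_PLUS_8 = timezone(timedelta(hours=8))

naive = datetime.strptime("2024-01-01 12:00:00", "%Y-%m-%d %H:%M:%S")

# replace() attaches tzinfo: the wall-clock reading stays at 12:00.
attached = naive.replace(tzinfo=UTC_PLUS_8)

# astimezone() converts: the instant is unchanged, the reading shifts.
aware_utc = datetime(2024, 1, 1, 12, 0, tzinfo=timezone.utc)
converted = aware_utc.astimezone(UTC_PLUS_8)

# A Unix timestamp maps to an aware UTC datetime, as in to_utc().
from_epoch = datetime.fromtimestamp(0, tz=timezone.utc)
```

Under these assumptions, a string parsed by `from_str` is interpreted as local wall-clock time in the configured zone, so callers holding UTC strings would likely want to parse and then convert rather than rely on `from_str` alone.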
@@ -6,6 +6,7 @@ readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
     "deerflow-harness",
+    "deerflow-storage",
     "fastapi>=0.115.0",
     "httpx>=0.28.0",
     "python-multipart>=0.0.27",
@@ -24,7 +25,8 @@ dependencies = [
 ]

 [project.optional-dependencies]
-postgres = ["deerflow-harness[postgres]"]
+postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
+mysql = ["deerflow-storage[mysql]"]

 [dependency-groups]
 dev = [
@@ -43,7 +45,8 @@ markers = [
 index-url = "https://pypi.org/simple"

 [tool.uv.workspace]
-members = ["packages/harness"]
+members = ["packages/harness", "packages/storage"]

 [tool.uv.sources]
 deerflow-harness = { workspace = true }
+deerflow-storage = { workspace = true }
@@ -94,12 +94,15 @@ class TestHarnessPackaging:
             "psycopg-pool>=3.3.0",
         ]

-    def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
+    def test_workspace_pyproject_forwards_postgres_extra_to_storage_packages(self):
         pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
         data = tomllib.loads(pyproject_path.read_text())

         optional_dependencies = data["project"]["optional-dependencies"]
-        assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
+        assert optional_dependencies["postgres"] == [
+            "deerflow-harness[postgres]",
+            "deerflow-storage[postgres]",
+        ]

     def test_postgres_missing_dependency_messages_recommend_package_extra(self):
         assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
@@ -0,0 +1,81 @@
from __future__ import annotations

import os
from pathlib import Path

import pytest
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.types import JSON

os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))

from store.persistence.json_compat import json_match


def _table():
    metadata = MetaData()
    return Table("t", metadata, Column("data", JSON), Column("id", String))


def test_storage_json_match_compiles_sqlite() -> None:
    from sqlalchemy import create_engine

    table = _table()
    dialect = create_engine("sqlite://").dialect

    assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'null'")
    assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'true'")

    int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert "= 'integer'" in int_sql
    assert "CAST" in int_sql

    float_sql = str(json_match(table.c.data, "k", 3.14).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert "IN ('integer', 'real')" in float_sql
    assert "REAL" in float_sql


def test_storage_json_match_compiles_postgres() -> None:
    table = _table()
    dialect = postgresql.dialect()

    assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_typeof(t.data -> 'k') = 'null'")
    assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')")

    int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert "CASE WHEN" in int_sql
    assert "BIGINT" in int_sql
    assert "'^-?[0-9]+$'" in int_sql


def test_storage_json_match_compiles_mysql() -> None:
    table = _table()
    dialect = mysql.dialect()

    null_sql = str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert null_sql == "JSON_TYPE(JSON_EXTRACT(t.data, '$.\"k\"')) = 'NULL'"

    bool_sql = str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert "JSON_TYPE(JSON_EXTRACT" in bool_sql
    assert "= 'BOOLEAN'" in bool_sql
    assert "= 'true'" in bool_sql

    int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
    assert "= 'INTEGER'" in int_sql
    assert "SIGNED" in int_sql


def test_storage_json_match_rejects_unsafe_keys_and_values() -> None:
    table = _table()

    for bad_key in ["a.b", "bad;key", "with space", "", 42, None]:
        with pytest.raises(ValueError, match="JsonMatch key must match"):
            json_match(table.c.data, bad_key, "x")  # type: ignore[arg-type]

    for bad_value in [[], {}, object()]:
        with pytest.raises(TypeError, match="JsonMatch value must be"):
            json_match(table.c.data, "k", bad_value)

    with pytest.raises(TypeError, match="out of signed 64-bit range"):
        json_match(table.c.data, "k", 2**63)
@@ -0,0 +1,122 @@
from __future__ import annotations

import os
import subprocess
import sys
from pathlib import Path
from types import SimpleNamespace

import pytest

os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))

from store.config.storage_config import StorageConfig
from store.persistence.factory import _create_database_url, storage_config_from_database_config


def test_database_sqlite_config_maps_to_storage_config(tmp_path):
    database = SimpleNamespace(
        backend="sqlite",
        sqlite_dir=str(tmp_path),
        echo_sql=True,
        pool_size=9,
    )

    storage = storage_config_from_database_config(database)

    assert storage == StorageConfig(
        driver="sqlite",
        sqlite_dir=str(tmp_path),
        echo_sql=True,
        pool_size=9,
    )
    assert storage.sqlite_storage_path == str(tmp_path / "deerflow.db")


def test_database_memory_config_is_not_a_storage_backend():
    database = SimpleNamespace(backend="memory")

    with pytest.raises(ValueError, match="Unsupported database backend"):
        storage_config_from_database_config(database)


def test_database_postgres_config_preserves_url_and_pool_options():
    database = SimpleNamespace(
        backend="postgres",
        postgres_url="postgresql://user:pass@db.example:5544/deerflow",
        echo_sql=True,
        pool_size=11,
    )

    storage = storage_config_from_database_config(database)
    url = _create_database_url(storage)

    assert storage.driver == "postgres"
    assert storage.database_url == "postgresql://user:pass@db.example:5544/deerflow"
    assert storage.username == "user"
    assert storage.password == "pass"
    assert storage.host == "db.example"
    assert storage.port == 5544
    assert storage.db_name == "deerflow"
    assert storage.echo_sql is True
    assert storage.pool_size == 11
    assert url.drivername == "postgresql+asyncpg"
    assert url.database == "deerflow"


def test_mysql_database_url_is_normalized_to_async_driver():
    storage = StorageConfig(
        driver="mysql",
        database_url="mysql://user:pass@db.example:3306/deerflow",
    )

    url = _create_database_url(storage)

    assert url.drivername == "mysql+aiomysql"
    assert url.database == "deerflow"


def test_mysql_async_database_url_is_preserved():
    storage = StorageConfig(
        driver="mysql",
        database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow",
    )

    url = _create_database_url(storage)

    assert url.drivername == "mysql+asyncmy"
    assert url.database == "deerflow"


def test_database_postgres_requires_url():
    database = SimpleNamespace(backend="postgres", postgres_url="")

    with pytest.raises(ValueError, match="database.postgres_url is required"):
        storage_config_from_database_config(database)


def test_unsupported_database_backend_rejected():
    database = SimpleNamespace(backend="oracle")

    with pytest.raises(ValueError, match="Unsupported database backend"):
        storage_config_from_database_config(database)


def test_storage_models_import_without_config_file(tmp_path):
    env = os.environ.copy()
    env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml")

    result = subprocess.run(
        [
            sys.executable,
            "-c",
            "from store.persistence.base_model import UniversalText, id_key; from store.repositories.models import RunEvent; print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
        ],
        check=False,
        capture_output=True,
        env=env,
        text=True,
    )

    assert result.returncode == 0, result.stderr
    assert "UniversalText run_events" in result.stdout
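The URL tests above expect a plain `postgresql://` or `mysql://` URL to gain an async driver, and an explicit driver such as `mysql+asyncmy` to be left alone. One way such a rule could look, using only string handling, is sketched below. `normalize_async_drivername` and the `_ASYNC_DRIVERS` table are hypothetical and inferred from the assertions; the actual `_create_database_url` works on `StorageConfig` and returns a SQLAlchemy URL object, which this sketch does not attempt to reproduce.

```python
# Assumed default async drivers, inferred from the assertions in the tests.
_ASYNC_DRIVERS = {"postgresql": "asyncpg", "mysql": "aiomysql"}


def normalize_async_drivername(database_url: str) -> str:
    """Hypothetical helper: add a default async driver when the URL names none."""
    scheme, sep, rest = database_url.partition("://")
    if not sep:
        raise ValueError(f"Not a database URL: {database_url!r}")
    if "+" in scheme:
        # An explicit driver such as mysql+asyncmy is left untouched.
        return database_url
    driver = _ASYNC_DRIVERS.get(scheme)
    if driver is None:
        raise ValueError(f"Unsupported database backend: {scheme!r}")
    return f"{scheme}+{driver}://{rest}"


normalized_mysql = normalize_async_drivername("mysql://user:pass@db.example:3306/deerflow")
preserved_mysql = normalize_async_drivername("mysql+asyncmy://user:pass@db.example:3306/deerflow")
normalized_pg = normalize_async_drivername("postgresql://user:pass@db.example:5544/deerflow")
```

Only the scheme is rewritten here; credentials, host, port, and database name pass through unchanged, which matches what the tests check on `url.database`.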
@@ -0,0 +1,58 @@
from __future__ import annotations

import asyncio
import os
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4

os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))

from sqlalchemy import inspect
from store.persistence import create_persistence_from_database_config
from store.repositories import UserCreate, build_user_repository


def test_sqlite_persistence_from_database_config_creates_storage_tables(tmp_path):
    async def run() -> None:
        persistence = await create_persistence_from_database_config(
            SimpleNamespace(
                backend="sqlite",
                sqlite_dir=str(tmp_path),
                echo_sql=False,
                pool_size=5,
            )
        )
        assert persistence is not None
        try:
            await persistence.setup()

            async with persistence.engine.connect() as conn:
                tables = await conn.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names()))

            assert {
                "users",
                "runs",
                "run_events",
                "threads_meta",
                "feedback",
            }.issubset(tables)

            async with persistence.session_factory() as session:
                repo = build_user_repository(session)
                user = await repo.create_user(
                    UserCreate(
                        id=str(uuid4()),
                        email="storage-user@example.com",
                        password_hash="hash",
                    )
                )
                await session.commit()

            async with persistence.session_factory() as session:
                repo = build_user_repository(session)
                assert await repo.get_user_by_id(user.id) == user
        finally:
            await persistence.aclose()

    asyncio.run(run())
@@ -0,0 +1,395 @@
from __future__ import annotations

import os
from datetime import UTC, datetime, timedelta
from pathlib import Path
from types import SimpleNamespace

import pytest

os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))

from store.persistence import create_persistence_from_database_config
from store.repositories import (
    FeedbackCreate,
    InvalidMetadataFilterError,
    RunCreate,
    RunEventCreate,
    ThreadMetaCreate,
    build_feedback_repository,
    build_run_event_repository,
    build_run_repository,
    build_thread_meta_repository,
)


async def _make_persistence(tmp_path):
    persistence = await create_persistence_from_database_config(
        SimpleNamespace(
            backend="sqlite",
            sqlite_dir=str(tmp_path),
            echo_sql=False,
            pool_size=5,
        )
    )
    await persistence.setup()
    return persistence


@pytest.mark.anyio
async def test_storage_run_repository_filters_and_aggregates(tmp_path):
    persistence = await _make_persistence(tmp_path)
    old = datetime.now(UTC) - timedelta(hours=1)
    newer = datetime.now(UTC)
    try:
        async with persistence.session_factory() as session:
            repo = build_run_repository(session)
            await repo.create_run(
                RunCreate(
                    run_id="run-old",
                    thread_id="thread-1",
                    user_id="alice",
                    status="pending",
                    model_name="model-a",
                    metadata={"kind": "draft"},
                    kwargs={"temperature": 0.2},
                    created_time=old,
                )
            )
            await repo.create_run(
                RunCreate(
                    run_id="run-new",
                    thread_id="thread-1",
                    user_id="bob",
                    status="running",
                    model_name="model-b",
                    error="queued",
                    created_time=newer,
                )
            )
            await repo.create_run(RunCreate(run_id="run-other", thread_id="thread-2", status="running"))
            await repo.update_run_completion(
                "run-old",
                status="success",
                total_input_tokens=7,
                total_output_tokens=3,
                total_tokens=10,
                llm_call_count=1,
                lead_agent_tokens=8,
                subagent_tokens=2,
                first_human_message="hello",
                last_ai_message="world",
            )
            await repo.update_run_completion(
                "run-new",
                status="error",
                total_tokens=5,
                middleware_tokens=5,
                error="failed",
            )
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_run_repository(session)
            fetched = await repo.get_run("run-old")
            assert fetched is not None
            assert fetched.metadata == {"kind": "draft"}
            assert fetched.kwargs == {"temperature": 0.2}
            assert fetched.first_human_message == "hello"
            assert fetched.last_ai_message == "world"

            all_thread_runs = await repo.list_runs_by_thread("thread-1")
            assert [run.run_id for run in all_thread_runs] == ["run-new", "run-old"]
            alice_runs = await repo.list_runs_by_thread("thread-1", user_id="alice")
            assert [run.run_id for run in alice_runs] == ["run-old"]

            pending = await repo.list_pending(before=datetime.now(UTC).isoformat())
            assert [run.run_id for run in pending] == []

            agg = await repo.aggregate_tokens_by_thread("thread-1")
            assert agg["total_tokens"] == 15
            assert agg["total_input_tokens"] == 7
            assert agg["total_output_tokens"] == 3
            assert agg["total_runs"] == 2
            assert agg["by_model"] == {
                "model-a": {"tokens": 10, "runs": 1},
                "model-b": {"tokens": 5, "runs": 1},
            }
            assert agg["by_caller"] == {"lead_agent": 8, "subagent": 2, "middleware": 5}
    finally:
        await persistence.aclose()


@pytest.mark.anyio
async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            await repo.create_thread_meta(
                ThreadMetaCreate(
                    thread_id="thread-1",
                    assistant_id="agent-a",
                    user_id="alice",
                    display_name="Initial",
                    status="idle",
                    metadata={"topic": "finance", "region": "cn"},
                )
            )
            await repo.create_thread_meta(
                ThreadMetaCreate(
                    thread_id="thread-2",
                    assistant_id="agent-b",
                    user_id="bob",
                    status="running",
                    metadata={"topic": "legal"},
                )
            )
            await repo.update_thread_meta(
                "thread-1",
                display_name="Updated",
                status="running",
                metadata={"topic": "finance", "region": "us"},
            )
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            fetched = await repo.get_thread_meta("thread-1")
            assert fetched is not None
            assert fetched.display_name == "Updated"
            assert fetched.status == "running"
            assert fetched.metadata == {"topic": "finance", "region": "us"}

            by_metadata = await repo.search_threads(metadata={"topic": "finance"}, user_id="alice")
            assert [thread.thread_id for thread in by_metadata] == ["thread-1"]
            by_assistant = await repo.search_threads(assistant_id="agent-b")
            assert [thread.thread_id for thread in by_assistant] == ["thread-2"]

            await repo.delete_thread("thread-1")
            await session.commit()

        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            assert await repo.get_thread_meta("thread-1") is None
    finally:
        await persistence.aclose()


@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filters_are_type_safe(tmp_path):
    persistence = await _make_persistence(tmp_path)
    try:
        async with persistence.session_factory() as session:
            repo = build_thread_meta_repository(session)
            await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-true", metadata={"value": True}))
            await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-false", metadata={"value": False}))
|
||||||
|
await repo.create_thread_meta(ThreadMetaCreate(thread_id="int-one", metadata={"value": 1}))
|
||||||
|
await repo.create_thread_meta(ThreadMetaCreate(thread_id="null-value", metadata={"value": None}))
|
||||||
|
await repo.create_thread_meta(ThreadMetaCreate(thread_id="missing-value", metadata={"other": "x"}))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_thread_meta_repository(session)
|
||||||
|
assert [row.thread_id for row in await repo.search_threads(metadata={"value": True})] == ["bool-true"]
|
||||||
|
assert [row.thread_id for row in await repo.search_threads(metadata={"value": False})] == ["bool-false"]
|
||||||
|
assert [row.thread_id for row in await repo.search_threads(metadata={"value": 1})] == ["int-one"]
|
||||||
|
assert [row.thread_id for row in await repo.search_threads(metadata={"value": None})] == ["null-value"]
|
||||||
|
finally:
|
||||||
|
await persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_storage_thread_meta_metadata_filters_paginate_after_sql_match(tmp_path):
|
||||||
|
persistence = await _make_persistence(tmp_path)
|
||||||
|
try:
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_thread_meta_repository(session)
|
||||||
|
for index in range(30):
|
||||||
|
metadata = {"target": "yes"} if index % 3 == 0 else {"target": "no"}
|
||||||
|
await repo.create_thread_meta(ThreadMetaCreate(thread_id=f"thread-{index:02d}", metadata=metadata))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_thread_meta_repository(session)
|
||||||
|
first_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=0)
|
||||||
|
second_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=3)
|
||||||
|
last_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=9)
|
||||||
|
|
||||||
|
assert len(first_page) == 3
|
||||||
|
assert len(second_page) == 3
|
||||||
|
assert len(last_page) == 1
|
||||||
|
assert {row.thread_id for row in first_page}.isdisjoint({row.thread_id for row in second_page})
|
||||||
|
finally:
|
||||||
|
await persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_storage_thread_meta_metadata_filter_rejects_invalid_entries(tmp_path):
|
||||||
|
persistence = await _make_persistence(tmp_path)
|
||||||
|
try:
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_thread_meta_repository(session)
|
||||||
|
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-1", metadata={"env": "prod"}))
|
||||||
|
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-2", metadata={"env": "staging"}))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_thread_meta_repository(session)
|
||||||
|
partial = await repo.search_threads(metadata={"env": "prod", "bad;key": "ignored"})
|
||||||
|
assert [row.thread_id for row in partial] == ["thread-1"]
|
||||||
|
|
||||||
|
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
|
||||||
|
await repo.search_threads(metadata={"bad;key": "x"})
|
||||||
|
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
|
||||||
|
await repo.search_threads(metadata={"env": ["prod", "staging"]})
|
||||||
|
finally:
|
||||||
|
await persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
|
||||||
|
persistence = await _make_persistence(tmp_path)
|
||||||
|
try:
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_feedback_repository(session)
|
||||||
|
first = await repo.create_feedback(
|
||||||
|
FeedbackCreate(
|
||||||
|
feedback_id="fb-1",
|
||||||
|
run_id="run-1",
|
||||||
|
thread_id="thread-1",
|
||||||
|
rating=1,
|
||||||
|
user_id="alice",
|
||||||
|
message_id="msg-1",
|
||||||
|
comment="good",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
second = await repo.create_feedback(
|
||||||
|
FeedbackCreate(
|
||||||
|
feedback_id="fb-2",
|
||||||
|
run_id="run-1",
|
||||||
|
thread_id="thread-1",
|
||||||
|
rating=-1,
|
||||||
|
user_id="bob",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_feedback_repository(session)
|
||||||
|
assert await repo.get_feedback(first.feedback_id) == first
|
||||||
|
assert [item.feedback_id for item in await repo.list_feedback_by_run("run-1")] == [
|
||||||
|
second.feedback_id,
|
||||||
|
first.feedback_id,
|
||||||
|
]
|
||||||
|
assert {item.feedback_id for item in await repo.list_feedback_by_thread("thread-1")} == {
|
||||||
|
"fb-1",
|
||||||
|
"fb-2",
|
||||||
|
}
|
||||||
|
assert await repo.delete_feedback("fb-1") is True
|
||||||
|
assert await repo.delete_feedback("missing") is False
|
||||||
|
with pytest.raises(ValueError, match="rating must be"):
|
||||||
|
await repo.create_feedback(
|
||||||
|
FeedbackCreate(
|
||||||
|
feedback_id="fb-bad",
|
||||||
|
run_id="run-1",
|
||||||
|
thread_id="thread-1",
|
||||||
|
rating=0,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_feedback_repository(session)
|
||||||
|
assert await repo.get_feedback("fb-1") is None
|
||||||
|
finally:
|
||||||
|
await persistence.aclose()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.anyio
|
||||||
|
async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_path):
|
||||||
|
persistence = await _make_persistence(tmp_path)
|
||||||
|
try:
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_run_event_repository(session)
|
||||||
|
rows = await repo.append_batch(
|
||||||
|
[
|
||||||
|
RunEventCreate(
|
||||||
|
thread_id="thread-1",
|
||||||
|
run_id="run-1",
|
||||||
|
user_id="alice",
|
||||||
|
event_type="message",
|
||||||
|
category="message",
|
||||||
|
content={"role": "user", "content": "hello"},
|
||||||
|
metadata={"source": "input"},
|
||||||
|
),
|
||||||
|
RunEventCreate(
|
||||||
|
thread_id="thread-1",
|
||||||
|
run_id="run-1",
|
||||||
|
event_type="tool",
|
||||||
|
category="debug",
|
||||||
|
content="tool-call",
|
||||||
|
),
|
||||||
|
RunEventCreate(
|
||||||
|
thread_id="thread-1",
|
||||||
|
run_id="run-2",
|
||||||
|
event_type="message",
|
||||||
|
category="message",
|
||||||
|
content="second",
|
||||||
|
),
|
||||||
|
RunEventCreate(
|
||||||
|
thread_id="thread-2",
|
||||||
|
run_id="run-3",
|
||||||
|
event_type="message",
|
||||||
|
category="message",
|
||||||
|
content="other-thread",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
|
||||||
|
assert [row.seq for row in rows] == sorted(row.seq for row in rows)
|
||||||
|
assert rows[1].seq == rows[0].seq + 1
|
||||||
|
assert rows[2].seq == rows[1].seq + 1
|
||||||
|
assert rows[0].content == {"role": "user", "content": "hello"}
|
||||||
|
assert rows[0].metadata == {"source": "input", "content_is_json": True}
|
||||||
|
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_run_event_repository(session)
|
||||||
|
messages = await repo.list_messages("thread-1", limit=2)
|
||||||
|
assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
|
||||||
|
assert await repo.count_messages("thread-1") == 2
|
||||||
|
|
||||||
|
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
|
||||||
|
assert [event.seq for event in after] == [rows[0].seq]
|
||||||
|
before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
|
||||||
|
assert [event.seq for event in before] == [rows[0].seq]
|
||||||
|
|
||||||
|
events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
|
||||||
|
assert [event.content for event in events] == ["tool-call"]
|
||||||
|
|
||||||
|
assert await repo.delete_by_run("thread-1", "run-1") == 2
|
||||||
|
assert await repo.delete_by_thread("thread-2") == 1
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
async with persistence.session_factory() as session:
|
||||||
|
repo = build_run_event_repository(session)
|
||||||
|
remaining = await repo.list_events("thread-1", "run-2")
|
||||||
|
assert [event.seq for event in remaining] == [rows[2].seq]
|
||||||
|
assert await repo.count_messages("thread-2") == 0
|
||||||
|
|
||||||
|
later = await repo.append_batch(
|
||||||
|
[
|
||||||
|
RunEventCreate(
|
||||||
|
thread_id="thread-1",
|
||||||
|
run_id="run-4",
|
||||||
|
event_type="message",
|
||||||
|
category="message",
|
||||||
|
content="after-delete",
|
||||||
|
)
|
||||||
|
]
|
||||||
|
)
|
||||||
|
assert later[0].seq > rows[2].seq
|
||||||
|
finally:
|
||||||
|
await persistence.aclose()
|
||||||
@@ -0,0 +1,177 @@
from __future__ import annotations

import asyncio
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from uuid import uuid4

import pytest

os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))

from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.repositories import UserCreate, UserNotFoundError, build_user_repository
from store.repositories.models import User as UserModel


@asynccontextmanager
async def _session_factory(tmp_path) -> AsyncGenerator[async_sessionmaker[AsyncSession]]:
    db_path = tmp_path / "storage-users.db"
    engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
    async with engine.begin() as conn:
        await conn.run_sync(UserModel.metadata.create_all)

    try:
        yield async_sessionmaker(engine, expire_on_commit=False)
    finally:
        await engine.dispose()


async def _create_user(
    session_factory: async_sessionmaker[AsyncSession],
    *,
    email: str = "user@example.com",
    system_role: str = "user",
    oauth_provider: str | None = None,
    oauth_id: str | None = None,
):
    async with session_factory() as session:
        repo = build_user_repository(session)
        user = await repo.create_user(
            UserCreate(
                id=str(uuid4()),
                email=email,
                password_hash="hash",
                system_role=system_role,  # type: ignore[arg-type]
                oauth_provider=oauth_provider,
                oauth_id=oauth_id,
            )
        )
        await session.commit()
        return user


def test_create_and_get_user_by_id_and_email(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            created = await _create_user(session_factory)

            async with session_factory() as session:
                repo = build_user_repository(session)

                by_id = await repo.get_user_by_id(created.id)
                by_email = await repo.get_user_by_email(created.email)

                assert by_id == created
                assert by_email == created
                assert created.system_role == "user"
                assert created.needs_setup is False
                assert created.token_version == 0

    asyncio.run(run())


def test_duplicate_email_raises_value_error(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email="dupe@example.com")

            async with session_factory() as session:
                repo = build_user_repository(session)
                with pytest.raises(ValueError, match="Email already registered"):
                    await repo.create_user(
                        UserCreate(
                            id=str(uuid4()),
                            email="dupe@example.com",
                            password_hash="hash",
                        )
                    )

    asyncio.run(run())


def test_oauth_lookup_and_plain_users_without_oauth(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email="local-1@example.com")
            await _create_user(session_factory, email="local-2@example.com")
            oauth_user = await _create_user(
                session_factory,
                email="oauth@example.com",
                oauth_provider="github",
                oauth_id="gh-123",
            )

            async with session_factory() as session:
                repo = build_user_repository(session)

                assert await repo.count_users() == 3
                assert await repo.get_user_by_oauth("github", "gh-123") == oauth_user
                assert await repo.get_user_by_oauth("github", "missing") is None

    asyncio.run(run())


def test_count_admins_and_get_first_admin(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            await _create_user(session_factory, email="user@example.com")
            admin = await _create_user(
                session_factory,
                email="admin@example.com",
                system_role="admin",
            )

            async with session_factory() as session:
                repo = build_user_repository(session)

                assert await repo.count_users() == 2
                assert await repo.count_admin_users() == 1
                assert await repo.get_first_admin() == admin

    asyncio.run(run())


def test_update_user_round_trips_token_version_and_setup_state(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            created = await _create_user(session_factory)
            updated = created.model_copy(
                update={
                    "email": "renamed@example.com",
                    "token_version": 4,
                    "needs_setup": True,
                }
            )

            async with session_factory() as session:
                repo = build_user_repository(session)
                saved = await repo.update_user(updated)
                await session.commit()

            async with session_factory() as session:
                repo = build_user_repository(session)
                fetched = await repo.get_user_by_id(created.id)

                assert saved.email == "renamed@example.com"
                assert fetched == updated

    asyncio.run(run())


def test_update_missing_user_raises(tmp_path):
    async def run() -> None:
        async with _session_factory(tmp_path) as session_factory:
            missing = UserCreate(id=str(uuid4()), email="missing@example.com")

            async with session_factory() as session:
                repo = build_user_repository(session)
                created_shape = await repo.create_user(missing)
                await session.rollback()

                with pytest.raises(UserNotFoundError):
                    await repo.update_user(created_shape)

    asyncio.run(run())
@@ -1214,11 +1214,12 @@ def test_terminal_event_usage_none_when_no_records(monkeypatch):
     assert completed[0]["usage"] is None


-def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
+@pytest.mark.parametrize("error", [FileNotFoundError("missing config"), ValueError("invalid config")])
+def test_subagent_usage_cache_is_skipped_when_default_config_cannot_load(monkeypatch, error):
     monkeypatch.setattr(
         task_tool_module,
         "get_app_config",
-        MagicMock(side_effect=FileNotFoundError("missing config")),
+        MagicMock(side_effect=error),
     )

     assert task_tool_module._token_usage_cache_enabled(None) is False
Generated
+93
-1
@@ -14,6 +14,7 @@ resolution-markers = [
 members = [
     "deer-flow",
     "deerflow-harness",
+    "deerflow-storage",
 ]

 [[package]]
@@ -136,6 +137,18 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441, upload-time = "2026-03-31T22:00:12.791Z" },
 ]

+[[package]]
+name = "aiomysql"
+version = "0.3.2"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "pymysql" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/29/e0/302aeffe8d90853556f47f3106b89c16cc2ec2a4d269bdfd82e3f4ae12cc/aiomysql-0.3.2.tar.gz", hash = "sha256:72d15ef5cfc34c03468eb41e1b90adb9fd9347b0b589114bd23ead569a02ac1a", size = 108311, upload-time = "2025-10-22T00:15:21.278Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/4c/af/aae0153c3e28712adaf462328f6c7a3c196a1c1c27b491de4377dd3e6b52/aiomysql-0.3.2-py3-none-any.whl", hash = "sha256:c82c5ba04137d7afd5c693a258bea8ead2aad77101668044143a991e04632eb2", size = 71834, upload-time = "2025-10-22T00:15:15.905Z" },
+]
+
 [[package]]
 name = "aiosignal"
 version = "1.4.0"
@@ -746,6 +759,7 @@ source = { virtual = "." }
 dependencies = [
     { name = "bcrypt" },
     { name = "deerflow-harness" },
+    { name = "deerflow-storage" },
     { name = "dingtalk-stream" },
     { name = "email-validator" },
     { name = "fastapi" },
@@ -763,8 +777,12 @@ dependencies = [
 ]

 [package.optional-dependencies]
+mysql = [
+    { name = "deerflow-storage", extra = ["mysql"] },
+]
 postgres = [
     { name = "deerflow-harness", extra = ["postgres"] },
+    { name = "deerflow-storage", extra = ["postgres"] },
 ]

 [package.dev-dependencies]
@@ -780,6 +798,9 @@ requires-dist = [
     { name = "bcrypt", specifier = ">=4.0.0" },
     { name = "deerflow-harness", editable = "packages/harness" },
     { name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
+    { name = "deerflow-storage", editable = "packages/storage" },
+    { name = "deerflow-storage", extras = ["mysql"], marker = "extra == 'mysql'", editable = "packages/storage" },
+    { name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
     { name = "dingtalk-stream", specifier = ">=0.24.3" },
     { name = "email-validator", specifier = ">=2.0.0" },
     { name = "fastapi", specifier = ">=0.115.0" },
@@ -795,7 +816,7 @@ requires-dist = [
     { name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
     { name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
 ]
-provides-extras = ["postgres"]
+provides-extras = ["postgres", "mysql"]

 [package.metadata.requires-dev]
 dev = [
@@ -901,6 +922,54 @@ requires-dist = [
 ]
 provides-extras = ["ollama", "postgres", "pymupdf"]

+[[package]]
+name = "deerflow-storage"
+version = "0.1.0"
+source = { editable = "packages/storage" }
+dependencies = [
+    { name = "alembic" },
+    { name = "dotenv" },
+    { name = "langgraph" },
+    { name = "pydantic" },
+    { name = "pyyaml" },
+    { name = "sqlalchemy", extra = ["asyncio"] },
+]
+
+[package.optional-dependencies]
+mysql = [
+    { name = "aiomysql" },
+    { name = "langgraph-checkpoint-mysql" },
+]
+postgres = [
+    { name = "asyncpg" },
+    { name = "langgraph-checkpoint-postgres" },
+    { name = "psycopg", extra = ["binary"] },
+    { name = "psycopg-pool" },
+]
+sqlite = [
+    { name = "aiosqlite" },
+    { name = "langgraph-checkpoint-sqlite" },
+]
+
+[package.metadata]
+requires-dist = [
+    { name = "aiomysql", marker = "extra == 'mysql'", specifier = ">=0.2" },
+    { name = "aiosqlite", marker = "extra == 'sqlite'", specifier = ">=0.22.1" },
+    { name = "alembic", specifier = ">=1.13" },
+    { name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
+    { name = "dotenv", specifier = ">=0.9.9" },
+    { name = "langgraph", specifier = ">=1.1.9" },
+    { name = "langgraph-checkpoint-mysql", marker = "extra == 'mysql'", specifier = ">=3.0.0" },
+    { name = "langgraph-checkpoint-postgres", marker = "extra == 'postgres'", specifier = ">=3.0.5" },
+    { name = "langgraph-checkpoint-sqlite", marker = "extra == 'sqlite'", specifier = ">=3.0.3" },
+    { name = "psycopg", extras = ["binary"], marker = "extra == 'postgres'", specifier = ">=3.3.3" },
+    { name = "psycopg-pool", marker = "extra == 'postgres'", specifier = ">=3.3.0" },
+    { name = "pydantic", specifier = ">=2.12.5" },
+    { name = "pyyaml", specifier = ">=6.0.3" },
+    { name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0,<3.0" },
+]
+provides-extras = ["postgres", "mysql", "sqlite"]
+
 [[package]]
 name = "defusedxml"
 version = "0.7.1"
@@ -1914,6 +1983,20 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/b9/5a/6dba29dd89b0a46ae21c707da0f9d17e94f27d3e481ed15bc99d6bd20aa6/langgraph_checkpoint-4.0.2-py3-none-any.whl", hash = "sha256:59b0f29216128a629c58dd07c98aa004f82f51805d5573126ffb419b753ff253", size = 51000, upload-time = "2026-04-15T21:02:59.096Z" },
 ]

+[[package]]
+name = "langgraph-checkpoint-mysql"
+version = "3.0.0"
+source = { registry = "https://pypi.org/simple" }
+dependencies = [
+    { name = "langgraph-checkpoint" },
+    { name = "orjson" },
+    { name = "typing-extensions" },
+]
+sdist = { url = "https://files.pythonhosted.org/packages/e0/4e/0a6c78e5d3f2ca1525903c2363e721873594b6b77dd83537a6369193c474/langgraph_checkpoint_mysql-3.0.0.tar.gz", hash = "sha256:006aaa089f4c2fbd7b2c113b800ccd3dbb95f92203e656451677256b4b4f880f", size = 213142, upload-time = "2026-01-23T11:11:15.74Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/08/68/343103a7fae05523f9cecabbec2babdb737e66b4bf6ea48ae00c685ed11c/langgraph_checkpoint_mysql-3.0.0-py3-none-any.whl", hash = "sha256:7560ccd16e7596a047e15a307cec12dbd88fdcaab45a75759e5c6adef22a27d1", size = 38009, upload-time = "2026-01-23T11:11:14.697Z" },
+]
+
 [[package]]
 name = "langgraph-checkpoint-postgres"
 version = "3.0.5"
@@ -3442,6 +3525,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/e6/38/84bf29f4dd72e6c450546df6ca8f53021f764fd945ba67dcc235d39bc20e/pymupdf4llm-1.27.2.3-py3-none-any.whl", hash = "sha256:bd724b79fa3f06a5b28d7a65f7acfa8de56e04bdb603ac2d6dff315e0d151aaa", size = 77348, upload-time = "2026-04-24T14:11:04.305Z" },
 ]

+[[package]]
+name = "pymysql"
+version = "1.1.3"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/7f/ec/8d45c920e90445f0b75c590b32851853ed319763b0d8dff8d283052da8cf/pymysql-1.1.3.tar.gz", hash = "sha256:e70ebf2047a4edf6138cf79c68ad418ef620af65900aa585c5e8bfc95044d43a", size = 48207, upload-time = "2026-05-01T09:09:54.532Z" }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/8e/dc/9085f3d6f497e9b25fb40d6e8ecef3ddbb5cf977a949b933624a299f5c16/pymysql-1.1.3-py3-none-any.whl", hash = "sha256:8164ba62c552f6105f3b11753352d0f16b90d1703ba67d81923d5a8a5d1c5289", size = 45356, upload-time = "2026-05-01T09:09:53.316Z" },
+]
+
 [[package]]
 name = "pypdfium2"
 version = "5.7.1"