Compare commits

..

17 Commits

Author SHA1 Message Date
rayhpeng d7a2fff7e0 fix(task): tolerate invalid default config in usage cache 2026-05-15 16:52:11 +08:00
rayhpeng eabd78ce4e Merge branch 'main' of https://github.com/bytedance/deer-flow into rayhpeng/storage-package-base 2026-05-15 16:41:34 +08:00
Nan Gao 45060a9ffc fix(runtime): avoid postgres aggregate row lock (#2962) 2026-05-15 10:32:09 +08:00
LawranceLiao 722c690f4f fix(memory): isolate queued memory updates by agent (#2941)
* fix(memory): isolate queued memory updates by agent

* fix(memory): include user in queue identity

* Potential fix for pull request finding

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>

* Fix the lint error

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
2026-05-15 10:26:35 +08:00
rayhpeng 533d3fbfee docs(storage): add storage package design 2026-05-14 11:27:55 +08:00
dependabot[bot] ba864112a3 chore(deps): bump langsmith from 0.7.36 to 0.8.0 in /backend (#2943)
Bumps [langsmith](https://github.com/langchain-ai/langsmith-sdk) from 0.7.36 to 0.8.0.
- [Release notes](https://github.com/langchain-ai/langsmith-sdk/releases)
- [Commits](https://github.com/langchain-ai/langsmith-sdk/compare/v0.7.36...v0.8.0)

---
updated-dependencies:
- dependency-name: langsmith
  dependency-version: 0.8.0
  dependency-type: indirect
...

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-14 11:02:58 +08:00
AochenShen99 6e8e6a969b test: add blocking IO detector (#2924)
* test: add blocking IO detector

* test: add blocking IO probe option

* test: harden blocking IO probe lifecycle

* test: move blocking io detector to support
2026-05-13 23:56:06 +08:00
YuJitang eab7ae3d62 feat: stream subagent token usage to header via terminal task events (#2882)
* feat: real-time subagent token usage display in header and per-turn

Backend:
- Persist subagent token usage to AIMessage.usage_metadata via
  TokenUsageMiddleware, so accumulateUsage() naturally includes
  subagent tokens without frontend state management
- Cache subagent usage by tool_call_id in task_tool, write back
  to the dispatching AIMessage on next model response
- Emit subagent token usage on all terminal task events
  (task_completed, task_failed, task_cancelled, task_timed_out)
- Report subagent usage to parent RunJournal for API totals
- Search backward from ToolMessage to find dispatching AIMessage
  for correct multi-tool-call attribution

Frontend:
- Remove subagentUsage state, custom event handling, and prop
  threading — subagent tokens are now embedded in message metadata
- Simplify selectHeaderTokenUsage (no subagentUsage parameter)
- Per-turn inline badges show turn-specific usage via message
  accumulation
- Remove isLoading guard from MessageTokenUsageList for dynamic
  updates during streaming

* fix: prevent header token double counting from baseline reset race

onFinish, onError, and thread-switch useEffect all reset
pendingUsageBaselineMessageIdsRef to an empty Set. If
thread.isLoading is still true on the next render, all messages
pass the getMessagesAfterBaseline filter and their tokens are
added to backendUsage (which already includes them), causing
the header to display up to 2× the actual token count.

Capture current message IDs instead of using an empty Set so
that getMessagesAfterBaseline correctly returns no pending
messages even if thread.isLoading lags behind the stream end.

* fix: write back subagent tokens for all concurrent task tool calls

TokenUsageMiddleware only processed messages[-2], so when a
single model response dispatched multiple task tool calls only
the last ToolMessage had its cached subagent usage written back
to the dispatch AIMessage.usage_metadata. Earlier tasks' usage
stayed in _subagent_usage_cache indefinitely (leak) and never
appeared in the per-turn inline token display.

Walk backward through all consecutive ToolMessages before the
new AIMessage, and accumulate updates targeting the same
dispatch message into one state update so overlapping writes
don't clobber each other.

* fix: clean up subagent usage cache entry on task cancellation

When a task_tool invocation is cancelled via CancelledError, any
cached subagent usage entry leaked because the TokenUsageMiddleware
writeback path never fires after cancellation. Pop the cache entry
before re-raising to prevent unbounded growth of the module-level
_subagent_usage_cache dict.

* fix: address token usage review feedback

* fix: handle missing config for subagent usage cache

---------

Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
2026-05-13 23:52:19 +08:00
Xinmin Zeng f1a0ab699a fix(tools): preserve tool_search promotions across re-entrant get_available_tools (#2885)
* fix(tools): preserve tool_search promotions across re-entrant get_available_tools

Closes #2884.

``get_available_tools`` used to unconditionally call
``reset_deferred_registry()`` and rebuild a fresh ``DeferredToolRegistry``
on every invocation. That works for the first call of a request (the
ContextVar starts at its default of ``None``), but any RE-ENTRANT call
during the same async context — e.g. ``task_tool`` building a subagent's
toolset, or a custom middleware that rebuilds tools mid-run — wiped any
``tool_search`` promotions the parent agent had already made. The
``DeferredToolFilterMiddleware`` would then re-hide those tools from the
next model call, leaving the agent able to see a tool's name (via the
prior ``tool_search`` result that's still in conversation history) but
unable to invoke it.

Fix: when the ContextVar already holds a registry, reuse it instead of
rebuilding. Fresh requests still get a fresh registry because each new
graph run starts in a new asyncio task with the ContextVar at ``None``.

## Verification

- Unit-level reproduction (``test_get_available_tools_resets_registry_wiping_promotion``):
  promote a tool in the registry, call ``get_available_tools`` again, assert
  the promotion is preserved. Fails on main, passes on this branch.

- Graph-execution reproduction (two tests): drive a real
  ``langchain.agents.create_agent`` graph with the real
  ``DeferredToolFilterMiddleware`` through two model turns, including one
  that issues a re-entrant ``get_available_tools`` call to simulate the
  task_tool subagent path.

- Real-LLM end-to-end (``test_deferred_tool_promotion_real_llm.py``,
  opt-in via ``ONEAPI_E2E=1``): drives the same flow against a real
  OpenAI-compatible model (verified on GPT-5.4-mini through the one-api
  gateway), watches the model call the promoted ``fake_calculator``
  through the deferred-filter middleware, and asserts the right arithmetic
  result. Passes against the fixed branch.

- Companion update to ``test_tool_deduplication.py``: dropped the
  ``@patch("deerflow.tools.tools.reset_deferred_registry")`` decorators
  because the symbol is no longer imported there.

- Test fixtures in the new files patch ``deerflow.tools.tools.get_app_config``
  with a minimal ``model_construct``-ed ``AppConfig`` instead of calling
  the real loader, so they never trigger ``_apply_singleton_configs`` and
  never leak ``_memory_config``/``_title_config``/… mutations into the
  rest of the suite.

Full backend suite: 3208 passed / 14 skipped / 0 failed. ruff check + format clean.

* fix(tools): address Copilot review on #2885

- tools.py: rewrite the reuse-path comment to spell out (a) why we don't
  reconcile the registry against the current ``mcp_tools`` snapshot — the
  MCP cache doesn't refresh mid-graph-run, the lead agent's ``ToolNode``
  is already bound to the previous tool set anyway, and ``promote()``
  drops the entry so a naive re-sync misclassifies promotions as new
  tools — and (b) why the log uses ``max(0, …)`` to avoid negative
  counts when the cache shrinks between snapshots.
- Replace direct ``ts_mod._registry_var.set(None)`` in test fixtures with
  the public ``reset_deferred_registry()`` helper so tests don't couple
  to module internals.
- Correct the docstring path in ``test_deferred_tool_registry_promotion.py``
  to match the actual monkeypatch target (``deerflow.mcp.cache.get_cached_mcp_tools``).
- Rename
  ``test_get_available_tools_resets_registry_wiping_promotion`` to
  ``test_get_available_tools_preserves_promotions_across_reentrant_calls``
  so the test name describes the contract being asserted, not the bug it
  originally reproduced.

Full backend suite: 3208 passed / 14 skipped. Real-LLM e2e: 1 passed.
2026-05-13 23:45:47 +08:00
Eilen Shin 2a1ac06bf4 fix(persistence): reuse token usage model grouping expression (#2910) 2026-05-13 15:49:34 +08:00
rayhpeng d6b3a277a5 test(storage): update postgres extra expectation 2026-05-13 14:18:35 +08:00
rayhpeng def2a3ad79 fix(storage): address user repository test comment 2026-05-13 13:03:06 +08:00
rayhpeng 3c0b42d836 fix(storage): address code quality comments 2026-05-13 12:58:17 +08:00
rayhpeng 34ec205e1d style(storage): format storage package 2026-05-13 12:52:34 +08:00
rayhpeng 11a9041b65 fix(storage): address repository review feedback 2026-05-13 12:51:45 +08:00
rayhpeng d3066a1746 fix(storage): harden sql persistence compatibility 2026-05-13 11:26:25 +08:00
rayhpeng 485f8a2bf2 feat(storage): add storage package base 2026-05-12 19:08:37 +08:00
83 changed files with 6388 additions and 121 deletions
+1 -1
View File
@@ -628,7 +628,7 @@ See [`skills/public/claude-to-deerflow/SKILL.md`](skills/public/claude-to-deerfl
Complex tasks rarely fit in a single pass. DeerFlow decomposes them.
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output.
The lead agent can spawn sub-agents on the fly — each with its own scoped context, tools, and termination conditions. Sub-agents run in parallel when possible, report back structured results, and the lead agent synthesizes everything into a coherent output. When token usage tracking is enabled, completed sub-agent usage is attributed back to the dispatching step.
This is how DeerFlow handles tasks that take minutes to hours: a research task might fan out into a dozen sub-agents, each exploring a different angle, then converge into a single report — or a website — or a slide deck with generated visuals. One harness, many hands.
+1 -1
View File
@@ -165,7 +165,7 @@ Lead-agent middlewares are assembled in strict append order across `packages/har
8. **ToolErrorHandlingMiddleware** - Converts tool exceptions into error `ToolMessage`s so the run can continue instead of aborting
9. **SummarizationMiddleware** - Context reduction when approaching token limits (optional, if enabled)
10. **TodoListMiddleware** - Task tracking with `write_todos` tool (optional, if plan_mode)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional)
11. **TokenUsageMiddleware** - Records token usage metrics when token tracking is enabled (optional); subagent usage is cached by `tool_call_id` only while token usage is enabled and merged back into the dispatching AIMessage by message position rather than message id
12. **TitleMiddleware** - Auto-generates thread title after first complete exchange and normalizes structured message content before prompting the title model
13. **MemoryMiddleware** - Queues conversations for async memory update (filters to user + final AI responses)
14. **ViewImageMiddleware** - Injects base64 image data before LLM call (conditional on vision support)
+401
View File
@@ -0,0 +1,401 @@
# Storage Package Design
## Background
DeerFlow currently has several persistence responsibilities spread across app, gateway, runtime, and legacy persistence modules. This makes the persistence boundary difficult to reason about and creates several migration risks:
- Routers and runtime services can accidentally depend on concrete persistence implementations instead of stable contracts.
- User/auth, run metadata, thread metadata, feedback, run events, and checkpointer setup are initialized through different paths.
- Some persistence behavior is duplicated between memory, SQLite, and PostgreSQL-oriented code paths.
- Incremental migration is hard because app-level code and storage-level code are coupled.
- Adding or validating another SQL backend requires touching app/runtime code instead of a storage-owned package.
The storage package is introduced to make application data persistence a package-level capability with explicit contracts, a clear boundary, and SQL backend compatibility.
## Goals
- Provide a standalone `packages/storage` package for durable application data.
- Support SQLite, PostgreSQL, and MySQL through a shared persistence construction flow.
- Keep LangGraph checkpointer initialization compatible with the same database backend.
- Expose repository contracts as the only package-level data access boundary.
- Let the app layer depend on app-owned adapters under `app.infra.storage`, not on storage DB implementation classes.
- Allow the app/gateway migration to happen in small steps without forcing a large rewrite.
## Non-Goals
- This design does not remove legacy persistence in the first PR.
- This design does not move routers directly onto storage package models.
- This design does not make app routers own SQLAlchemy sessions.
- Cron persistence is intentionally out of scope for the storage package foundation.
- Memory backend is not part of the durable storage package. Memory compatibility, if still needed by app runtime, belongs outside `packages/storage`.
## Storage Design Principles
### Package-Owned Durable Storage
`packages/storage` owns durable application data persistence. It defines:
- configuration shape for storage-backed persistence
- SQLAlchemy models
- repository contracts and DTOs
- SQL repository implementations
- persistence factory functions
- compatibility helpers for config-driven initialization
The package should be usable without importing `app.gateway`, routers, auth providers, or runtime-specific gateway objects.
### SQL Backend Compatibility
The package supports three SQL backends:
- SQLite for local/single-node deployments
- PostgreSQL for production multi-node deployments
- MySQL for deployments that standardize on MySQL
Backend-specific differences are handled inside the storage package:
- SQLAlchemy async engine URL construction
- LangGraph checkpointer connection-string compatibility
- JSON metadata filtering across SQLite/PostgreSQL/MySQL
- SQL dialect behavior around locking, aggregation, and JSON type semantics
### Unified Persistence Bundle
Storage initialization returns an `AppPersistence` bundle:
```python
@dataclass(slots=True)
class AppPersistence:
checkpointer: Checkpointer
engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession]
setup: Callable[[], Awaitable[None]]
aclose: Callable[[], Awaitable[None]]
```
The app runtime can initialize persistence once, call `setup()`, and then inject:
- `checkpointer`
- `session_factory`
- repository adapters
This keeps checkpointer and application data aligned to the same backend without requiring routers to understand database configuration.
## Package Layout
```text
backend/packages/storage/
store/
config/
storage_config.py
app_config.py
persistence/
factory.py
types.py
base_model.py
json_compat.py
drivers/
sqlite.py
postgres.py
mysql.py
repositories/
contracts/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
models/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
db/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
factory.py
```
## Persistence Construction
The primary storage entrypoint is:
```python
from store.persistence import create_persistence_from_storage_config
persistence = await create_persistence_from_storage_config(storage_config)
await persistence.setup()
```
For app-level compatibility with existing database config shape:
```python
from store.persistence import create_persistence_from_database_config
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
```
Expected app startup flow:
```python
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
app.state.persistence = persistence
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
Expected app shutdown flow:
```python
await app.state.persistence.aclose()
```
## Repository Contract Design
Repository contracts are the storage package's public data access boundary. They live under `store.repositories.contracts` and are re-exported from `store.repositories`.
The key contract groups are:
- `UserRepositoryProtocol`
- `RunRepositoryProtocol`
- `ThreadMetaRepositoryProtocol`
- `FeedbackRepositoryProtocol`
- `RunEventRepositoryProtocol`
Each contract owns:
- input DTOs, such as `UserCreate`, `RunCreate`, `ThreadMetaCreate`
- output DTOs, such as `User`, `Run`, `ThreadMeta`
- repository protocol methods
- domain-specific exceptions when needed, such as `InvalidMetadataFilterError`
Repository construction is session-based:
```python
from store.repositories import build_run_repository
async with persistence.session_factory() as session:
repo = build_run_repository(session)
run = await repo.get_run(run_id)
```
This keeps transaction ownership explicit. The storage package does not hide commits or session lifecycle inside global singletons.
## App/Infra Calling Contract
The app layer should not call `store.repositories.db.*` directly. The intended app boundary is `app.infra.storage`.
`app.infra.storage` is responsible for:
- receiving `session_factory` from FastAPI runtime initialization
- owning session lifecycle for app-facing repository methods
- translating storage DTOs to app/gateway DTOs only when needed
- preserving the existing app-facing names during migration
- depending on storage repository protocols, not concrete DB classes
Expected adapter pattern:
```python
class StorageRunRepository(RunRepositoryProtocol):
def __init__(self, session_factory):
self._session_factory = session_factory
async def get_run(self, run_id: str):
async with self._session_factory() as session:
repo = build_run_repository(session)
return await repo.get_run(run_id)
```
For gateway compatibility, app state can keep existing names while the implementation changes:
```python
app.state.run_store = StorageRunStore(run_repository)
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
app.state.run_event_store = StorageRunEventStore(run_event_repository)
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
The app-facing objects may expose legacy method names during migration, but their internal data access should go through storage contracts.
## Boundary Rules
### Allowed Calls
Storage package callers may use:
```python
from store.persistence import create_persistence_from_database_config
from store.persistence import create_persistence_from_storage_config
from store.repositories import build_run_repository
from store.repositories import build_user_repository
from store.repositories import build_thread_meta_repository
from store.repositories import build_feedback_repository
from store.repositories import build_run_event_repository
from store.repositories import RunRepositoryProtocol
from store.repositories import UserRepositoryProtocol
```
App layer callers should use:
```python
from app.infra.storage import StorageRunRepository
from app.infra.storage import StorageUserDataRepository
from app.infra.storage import StorageThreadMetaRepository
from app.infra.storage import StorageFeedbackRepository
from app.infra.storage import StorageRunEventRepository
```
### Prohibited Calls
App/gateway/router/auth code must not import:
```python
from store.repositories.db import DbRunRepository
from store.repositories.models import Run
from store.persistence.base_model import MappedBase
```
Routers must not:
- create SQLAlchemy engines
- create SQLAlchemy sessions directly
- call storage DB repository classes directly
- commit/rollback storage transactions directly unless explicitly scoped by an infra adapter
- depend on storage SQLAlchemy model classes
Storage package code must not import:
```python
import app.gateway
import app.infra
import deerflow.runtime
```
The dependency direction is:
```text
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
```
The reverse direction is forbidden.
## Checkpointer Compatibility
The storage persistence bundle initializes the LangGraph checkpointer alongside application data persistence.
Backend-specific notes:
- SQLite uses `langgraph-checkpoint-sqlite`.
- PostgreSQL uses `langgraph-checkpoint-postgres` and requires a string `postgresql://...` connection URL.
- MySQL uses `langgraph-checkpoint-mysql` and requires a string MySQL connection URL.
SQLAlchemy may use async driver URLs such as `postgresql+asyncpg://...` or `mysql+aiomysql://...`, but LangGraph checkpointer constructors expect plain string connection URLs. This conversion belongs inside the storage driver implementation.
## JSON Metadata Filtering
Thread metadata search supports dialect-aware JSON filtering through `store.persistence.json_compat`.
The matcher supports:
- `None`
- `bool`
- `int`
- `float`
- `str`
It rejects:
- unsafe keys
- nested JSON path expressions
- dict/list values
- integers outside signed 64-bit range
This prevents SQL/JSON path injection, avoids compiled-cache type drift, and preserves type semantics such as `True != 1` and explicit JSON `null` not matching a missing key.
## Step-by-Step Implementation Plan
### Step 1: Introduce Storage Package Foundation
- Add `backend/packages/storage`.
- Add storage config models.
- Add `AppPersistence`.
- Add SQLite/PostgreSQL/MySQL persistence drivers.
- Add repository contracts, models, DB implementations, and factory helpers.
- Add package dependency wiring.
- Exclude cron persistence.
### Step 2: Harden Storage Backend Compatibility
- Validate SQLite setup and repository behavior.
- Validate PostgreSQL and MySQL with local E2E tests.
- Fix checkpointer connection-string compatibility.
- Fix PostgreSQL locking and aggregation differences.
- Add dialect-aware JSON metadata filtering.
### Step 3: Add App Infra Adapters
- Add `backend/app/infra/storage`.
- Implement app-facing repositories that own session lifecycle.
- Keep storage contracts as the only data access boundary.
- Add legacy compatibility adapters for existing app/gateway method shapes.
- Keep app/gateway imports out of `packages/storage`.
### Step 4: Switch FastAPI Runtime Injection
- Initialize storage persistence in FastAPI startup/lifespan.
- Attach `persistence`, `checkpointer`, and `session_factory` to `app.state`.
- Preserve existing external state names:
- `run_store`
- `feedback_repo`
- `thread_store`
- `run_event_store`
- `checkpointer`
- `session_factory`
- Start with user/auth provider construction, then migrate run/thread/feedback/run_event.
### Step 5: Router and Auth Compatibility
- Ensure routers consume app-facing adapters, not storage DB classes.
- Ensure auth providers depend on user repository contracts.
- Keep router response shapes unchanged.
- Add focused auth/admin/router regression tests.
### Step 6: Cleanup Legacy Persistence
- Compare old persistence usage after app/gateway migration.
- Remove unused old repository implementations only after all call sites move.
- Keep compatibility shims only where needed for a transition window.
- Delete memory backend paths from storage-owned durable persistence.
## Testing Strategy
Unit tests should cover:
- config parsing
- persistence setup
- table creation
- repository CRUD/query behavior
- typed JSON metadata filtering
- dialect SQL compilation
- cron exclusion
E2E tests should cover:
- SQLite persistence setup
- PostgreSQL temporary database setup
- MySQL temporary database setup
- repository contract behavior across all supported SQL backends
- JSON/Unicode round trip
- rollback behavior
- persistence close/cleanup
E2E tests may remain local-only if CI does not provide PostgreSQL/MySQL services.
+401
View File
@@ -0,0 +1,401 @@
# Storage Package 设计文档
## 背景
DeerFlow 当前有多类持久化职责分散在 app、gateway、runtime 和旧 persistence 模块中。这会带来几个问题:
- routers 和 runtime services 容易依赖具体 persistence 实现,而不是稳定契约。
- user/auth、run metadata、thread metadata、feedback、run events、checkpointer setup 的初始化路径不统一。
- memory、SQLite、PostgreSQL 相关路径中存在部分重复逻辑。
- app 层代码和 storage 层代码耦合,导致增量迁移困难。
- 增加或验证新的 SQL backend 时,需要改动 app/runtime,而不是只改 storage package。
引入 storage package 的目标,是把应用数据持久化抽象成 package 级能力,并提供明确契约、清晰边界和 SQL backend 兼容性。
## 目标
- 新增独立的 `packages/storage`,负责 durable application data。
- 通过统一 persistence 构造流程支持 SQLite、PostgreSQL、MySQL。
- 保持 LangGraph checkpointer 与同一个数据库 backend 兼容。
- 将 repository contracts 作为 package 对外唯一数据访问边界。
- app 层通过 `app.infra.storage` 适配 storage,而不是直接依赖 storage DB 实现类。
- 支持 app/gateway 后续小步迁移,避免一次性大重构。
## 非目标
- 第一阶段不删除旧 persistence。
- 不让 routers 直接依赖 storage package models。
- 不让 app routers 管理 SQLAlchemy sessions。
- cron persistence 不属于 storage package 基础迁移范围。
- memory backend 不属于 durable storage package。若 app runtime 仍需要 memory 兼容,应放在 `packages/storage` 之外。
## Storage 设计理念
### Package 自己负责 Durable Storage
`packages/storage` 负责应用数据的 durable persistence,包括:
- storage 持久化配置
- SQLAlchemy models
- repository contracts 和 DTOs
- SQL repository 实现
- persistence factory functions
- 面向现有 config 的兼容初始化入口
该 package 不应该 import `app.gateway`、routers、auth providers 或 runtime 中的 gateway 对象。
### SQL Backend 兼容
该 package 支持三种 SQL backend
- SQLite:本地或单节点部署
- PostgreSQL:生产多节点部署
- MySQL:使用 MySQL 作为标准数据库的部署
backend 差异在 storage package 内部处理:
- SQLAlchemy async engine URL 构造
- LangGraph checkpointer 连接串兼容
- SQLite/PostgreSQL/MySQL 的 JSON metadata filter
- 不同 SQL 方言在 locking、aggregation、JSON 类型语义上的差异
### 统一 Persistence Bundle
Storage 初始化返回 `AppPersistence` bundle
```python
@dataclass(slots=True)
class AppPersistence:
checkpointer: Checkpointer
engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession]
setup: Callable[[], Awaitable[None]]
aclose: Callable[[], Awaitable[None]]
```
app runtime 只需要初始化一次 persistence,调用 `setup()`,然后注入:
- `checkpointer`
- `session_factory`
- repository adapters
这样 checkpointer 和应用数据可以对齐到同一个 backend,同时 routers 不需要理解数据库配置。
## Package 结构
```text
backend/packages/storage/
store/
config/
storage_config.py
app_config.py
persistence/
factory.py
types.py
base_model.py
json_compat.py
drivers/
sqlite.py
postgres.py
mysql.py
repositories/
contracts/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
models/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
db/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
factory.py
```
## Persistence 构造
storage 的主要入口:
```python
from store.persistence import create_persistence_from_storage_config
persistence = await create_persistence_from_storage_config(storage_config)
await persistence.setup()
```
为了兼容现有 app database config,也提供:
```python
from store.persistence import create_persistence_from_database_config
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
```
预期 app startup 流程:
```python
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
app.state.persistence = persistence
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
预期 app shutdown 流程:
```python
await app.state.persistence.aclose()
```
## Repository 契约设计
Repository contracts 是 storage package 对外公开的数据访问边界。它们位于 `store.repositories.contracts`,并通过 `store.repositories` re-export。
主要契约包括:
- `UserRepositoryProtocol`
- `RunRepositoryProtocol`
- `ThreadMetaRepositoryProtocol`
- `FeedbackRepositoryProtocol`
- `RunEventRepositoryProtocol`
每组契约包含:
- 输入 DTO,例如 `UserCreate``RunCreate``ThreadMetaCreate`
- 输出 DTO,例如 `User``Run``ThreadMeta`
- repository protocol methods
- 必要的领域异常,例如 `InvalidMetadataFilterError`
Repository 通过 session 构造:
```python
from store.repositories import build_run_repository
async with persistence.session_factory() as session:
repo = build_run_repository(session)
run = await repo.get_run(run_id)
```
这样可以让 transaction ownership 保持明确。storage package 不通过全局 singleton 隐式隐藏 commit 或 session 生命周期。
## App/Infra 调用契约
app 层不应该直接调用 `store.repositories.db.*`。预期的 app 边界是 `app.infra.storage`
`app.infra.storage` 负责:
- 从 FastAPI runtime 初始化中接收 `session_factory`
- 为 app-facing repository methods 管理 session 生命周期
- 在必要时将 storage DTOs 转成 app/gateway DTOs
- 迁移期间保留现有 app-facing 名称
- 依赖 storage repository protocols,而不是具体 DB classes
预期 adapter 模式:
```python
class StorageRunRepository(RunRepositoryProtocol):
def __init__(self, session_factory):
self._session_factory = session_factory
async def get_run(self, run_id: str):
async with self._session_factory() as session:
repo = build_run_repository(session)
return await repo.get_run(run_id)
```
为了兼容 gatewayapp state 可以暂时保持现有名字,只替换内部实现:
```python
app.state.run_store = StorageRunStore(run_repository)
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
app.state.run_event_store = StorageRunEventStore(run_event_repository)
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
app-facing objects 可以在迁移期间保留旧方法名,但内部数据访问必须经过 storage contracts。
## 边界规则
### 允许调用的范围
storage package 调用方可以使用:
```python
from store.persistence import create_persistence_from_database_config
from store.persistence import create_persistence_from_storage_config
from store.repositories import build_run_repository
from store.repositories import build_user_repository
from store.repositories import build_thread_meta_repository
from store.repositories import build_feedback_repository
from store.repositories import build_run_event_repository
from store.repositories import RunRepositoryProtocol
from store.repositories import UserRepositoryProtocol
```
app 层应该使用:
```python
from app.infra.storage import StorageRunRepository
from app.infra.storage import StorageUserDataRepository
from app.infra.storage import StorageThreadMetaRepository
from app.infra.storage import StorageFeedbackRepository
from app.infra.storage import StorageRunEventRepository
```
### 禁止调用的范围
app/gateway/router/auth 代码不应该 import
```python
from store.repositories.db import DbRunRepository
from store.repositories.models import Run
from store.persistence.base_model import MappedBase
```
routers 禁止:
- 创建 SQLAlchemy engines
- 直接创建 SQLAlchemy sessions
- 直接调用 storage DB repository classes
- 直接 commit/rollback storage transactions,除非这是 infra adapter 明确管理的范围
- 依赖 storage SQLAlchemy model classes
storage package 禁止 import
```python
import app.gateway
import app.infra
import deerflow.runtime
```
依赖方向必须是:
```text
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
```
禁止反向依赖。
## Checkpointer 兼容
storage persistence bundle 会同时初始化 LangGraph checkpointer 和应用数据持久化。
backend 说明:
- SQLite 使用 `langgraph-checkpoint-sqlite`
- PostgreSQL 使用 `langgraph-checkpoint-postgres`,需要字符串形式的 `postgresql://...` 连接串。
- MySQL 使用 `langgraph-checkpoint-mysql`,需要字符串形式的 MySQL 连接串。
SQLAlchemy 可以使用 `postgresql+asyncpg://...``mysql+aiomysql://...` 这类 async driver URL,但 LangGraph checkpointer 构造函数需要普通字符串连接串。这个转换应该封装在 storage driver implementation 内部。
## JSON Metadata Filtering
Thread metadata search 通过 `store.persistence.json_compat` 支持跨方言 JSON filtering。
支持的 filter value 类型:
- `None`
- `bool`
- `int`
- `float`
- `str`
拒绝:
- unsafe keys
- nested JSON path expressions
- dict/list values
- 超出 signed 64-bit 范围的整数
这样可以避免 SQL/JSON path injection,避免 compiled-cache 类型漂移,并保留类型语义,例如 `True != 1`,显式 JSON `null` 不等于 missing key。
## 分步实现方案
### 第 1 步:新增 Storage Package 基础
- 新增 `backend/packages/storage`
- 增加 storage config models。
- 增加 `AppPersistence`
- 增加 SQLite/PostgreSQL/MySQL persistence drivers。
- 增加 repository contracts、models、DB implementations 和 factory helpers。
- 接入 package dependency。
- 排除 cron persistence。
### 第 2 步:补齐 Storage Backend 兼容性
- 验证 SQLite setup 和 repository 行为。
- 使用本地 E2E 验证 PostgreSQL 和 MySQL。
- 修复 checkpointer 连接串兼容。
- 修复 PostgreSQL locking 和 aggregation 差异。
- 增加跨方言 JSON metadata filtering。
### 第 3 步:新增 App Infra Adapters
- 新增 `backend/app/infra/storage`
- 实现 app-facing repositories,由它们管理 session 生命周期。
- 保持 storage contracts 作为唯一数据访问边界。
- 为现有 app/gateway method shape 增加兼容 adapters。
- 避免 `packages/storage` import app/gateway。
### 第 4 步:切换 FastAPI Runtime 注入
- 在 FastAPI startup/lifespan 中初始化 storage persistence。
-`persistence``checkpointer``session_factory` 注入 `app.state`
- 暂时保留现有对外 state 名称:
- `run_store`
- `feedback_repo`
- `thread_store`
- `run_event_store`
- `checkpointer`
- `session_factory`
- 先切 user/auth provider 构造,再逐步迁移 run/thread/feedback/run_event。
### 第 5 步:Router 和 Auth 兼容
- 确保 routers 消费 app-facing adapters,而不是 storage DB classes。
- 确保 auth providers 依赖 user repository contracts。
- 保持 router response shapes 不变。
- 增加 auth/admin/router regression tests。
### 第 6 步:清理旧 Persistence
- app/gateway 迁移完成后,再比较旧 persistence usage。
- 所有 call sites 迁移完成后,再删除未使用的旧 repository implementations。
- 只在必要时保留短期 compatibility shims。
- 从 storage-owned durable persistence 中移除 memory backend 路径。
## 测试策略
单测应覆盖:
- config parsing
- persistence setup
- table creation
- repository CRUD/query behavior
- typed JSON metadata filtering
- dialect SQL compilation
- cron exclusion
E2E 应覆盖:
- SQLite persistence setup
- PostgreSQL temporary database setup
- MySQL temporary database setup
- 所有支持 SQL backend 下的 repository contract 行为
- JSON/Unicode round trip
- rollback behavior
- persistence close/cleanup
如果 CI 暂时没有 PostgreSQL/MySQL servicesE2E 可以先作为 local-only 验证保留。
@@ -40,6 +40,15 @@ class MemoryUpdateQueue:
self._timer: threading.Timer | None = None
self._processing = False
@staticmethod
def _queue_key(
thread_id: str,
user_id: str | None,
agent_name: str | None,
) -> tuple[str, str | None, str | None]:
"""Return the debounce identity for a memory update target."""
return (thread_id, user_id, agent_name)
def add(
self,
thread_id: str,
@@ -115,8 +124,9 @@ class MemoryUpdateQueue:
correction_detected: bool,
reinforcement_detected: bool,
) -> None:
queue_key = self._queue_key(thread_id, user_id, agent_name)
existing_context = next(
(context for context in self._queue if context.thread_id == thread_id),
(context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) == queue_key),
None,
)
merged_correction_detected = correction_detected or (existing_context.correction_detected if existing_context is not None else False)
@@ -130,7 +140,7 @@ class MemoryUpdateQueue:
reinforcement_detected=merged_reinforcement_detected,
)
self._queue = [c for c in self._queue if c.thread_id != thread_id]
self._queue = [context for context in self._queue if self._queue_key(context.thread_id, context.user_id, context.agent_name) != queue_key]
self._queue.append(context)
def _reset_timer(self) -> None:
@@ -6,6 +6,7 @@ from deerflow.agents.memory.message_processing import detect_correction, detect_
from deerflow.agents.memory.queue import get_memory_queue
from deerflow.agents.middlewares.summarization_middleware import SummarizationEvent
from deerflow.config.memory_config import get_memory_config
from deerflow.runtime.user_context import resolve_runtime_user_id
def memory_flush_hook(event: SummarizationEvent) -> None:
@@ -21,11 +22,13 @@ def memory_flush_hook(event: SummarizationEvent) -> None:
correction_detected = detect_correction(filtered_messages)
reinforcement_detected = not correction_detected and detect_reinforcement(filtered_messages)
user_id = resolve_runtime_user_id(event.runtime)
queue = get_memory_queue()
queue.add_nowait(
thread_id=event.thread_id,
messages=filtered_messages,
agent_name=event.agent_name,
user_id=user_id,
correction_detected=correction_detected,
reinforcement_detected=reinforcement_detected,
)
@@ -9,7 +9,7 @@ from typing import Any, override
from langchain.agents import AgentState
from langchain.agents.middleware import AgentMiddleware
from langchain.agents.middleware.todo import Todo
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, ToolMessage
from langgraph.runtime import Runtime
logger = logging.getLogger(__name__)
@@ -217,6 +217,17 @@ def _infer_step_kind(message: AIMessage, actions: list[dict[str, Any]]) -> str:
return "thinking"
def _has_tool_call(message: AIMessage, tool_call_id: str) -> bool:
"""Return True if the AIMessage contains a tool_call with the given id."""
for tc in message.tool_calls or []:
if isinstance(tc, dict):
if tc.get("id") == tool_call_id:
return True
elif hasattr(tc, "id") and tc.id == tool_call_id:
return True
return False
def _build_attribution(message: AIMessage, todos: list[Todo]) -> dict[str, Any]:
tool_calls = getattr(message, "tool_calls", None) or []
actions: list[dict[str, Any]] = []
@@ -261,8 +272,51 @@ class TokenUsageMiddleware(AgentMiddleware):
if not messages:
return None
# Annotate subagent token usage onto the AIMessage that dispatched it.
# When a task tool completes, its usage is cached by tool_call_id. Detect
# the ToolMessage → search backward for the corresponding AIMessage → merge.
# Walk backward through consecutive ToolMessages before the new AIMessage
# so that multiple concurrent task tool calls all get their subagent tokens
# written back to the same dispatch message (merging into one update).
state_updates: dict[int, AIMessage] = {}
if len(messages) >= 2:
from deerflow.tools.builtins.task_tool import pop_cached_subagent_usage
idx = len(messages) - 2
while idx >= 0:
tool_msg = messages[idx]
if not isinstance(tool_msg, ToolMessage) or not tool_msg.tool_call_id:
break
subagent_usage = pop_cached_subagent_usage(tool_msg.tool_call_id)
if subagent_usage:
# Search backward from the ToolMessage to find the AIMessage
# that dispatched it. A single model response can dispatch
# multiple task tool calls, so we can't assume a fixed offset.
dispatch_idx = idx - 1
while dispatch_idx >= 0:
candidate = messages[dispatch_idx]
if isinstance(candidate, AIMessage) and _has_tool_call(candidate, tool_msg.tool_call_id):
# Accumulate into an existing update for the same
# AIMessage (multiple task calls in one response),
# or merge fresh from the original message.
existing_update = state_updates.get(dispatch_idx)
prev = existing_update.usage_metadata if existing_update else (getattr(candidate, "usage_metadata", None) or {})
merged = {
**prev,
"input_tokens": prev.get("input_tokens", 0) + subagent_usage["input_tokens"],
"output_tokens": prev.get("output_tokens", 0) + subagent_usage["output_tokens"],
"total_tokens": prev.get("total_tokens", 0) + subagent_usage["total_tokens"],
}
state_updates[dispatch_idx] = candidate.model_copy(update={"usage_metadata": merged})
break
dispatch_idx -= 1
idx -= 1
last = messages[-1]
if not isinstance(last, AIMessage):
if state_updates:
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
return None
usage = getattr(last, "usage_metadata", None)
@@ -288,11 +342,12 @@ class TokenUsageMiddleware(AgentMiddleware):
additional_kwargs = dict(getattr(last, "additional_kwargs", {}) or {})
if additional_kwargs.get(TOKEN_USAGE_ATTRIBUTION_KEY) == attribution:
return None
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]} if state_updates else None
additional_kwargs[TOKEN_USAGE_ATTRIBUTION_KEY] = attribution
updated_msg = last.model_copy(update={"additional_kwargs": additional_kwargs})
return {"messages": [updated_msg]}
state_updates[len(messages) - 1] = updated_msg
return {"messages": [state_updates[idx] for idx in sorted(state_updates)]}
@override
def after_model(self, state: AgentState, runtime: Runtime) -> dict | None:
@@ -223,10 +223,11 @@ class RunRepository(RunStore):
"""Aggregate token usage via a single SQL GROUP BY query."""
_completed = RunRow.status.in_(("success", "error"))
_thread = RunRow.thread_id == thread_id
model_name = func.coalesce(RunRow.model_name, "unknown")
stmt = (
select(
func.coalesce(RunRow.model_name, "unknown").label("model"),
model_name.label("model"),
func.count().label("runs"),
func.coalesce(func.sum(RunRow.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunRow.total_input_tokens), 0).label("total_input_tokens"),
@@ -236,7 +237,7 @@ class RunRepository(RunStore):
func.coalesce(func.sum(RunRow.middleware_tokens), 0).label("middleware"),
)
.where(_thread, _completed)
.group_by(func.coalesce(RunRow.model_name, "unknown"))
.group_by(model_name)
)
async with self._sf() as session:
@@ -11,7 +11,7 @@ import logging
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import delete, func, select
from sqlalchemy import delete, func, select, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from deerflow.persistence.models.run_event import RunEventRow
@@ -86,6 +86,28 @@ class DbRunEventStore(RunEventStore):
user = get_current_user()
return str(user.id) if user is not None else None
@staticmethod
async def _max_seq_for_thread(session: AsyncSession, thread_id: str) -> int | None:
"""Return the current max seq while serializing writers per thread.
PostgreSQL rejects ``SELECT max(...) FOR UPDATE`` because aggregate
results are not lockable rows. As a release-safe workaround, take a
transaction-level advisory lock keyed by thread_id before reading the
aggregate. Other dialects keep the existing row-locking statement.
"""
stmt = select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id)
bind = session.get_bind()
dialect_name = bind.dialect.name if bind is not None else ""
if dialect_name == "postgresql":
await session.execute(
text("SELECT pg_advisory_xact_lock(hashtext(CAST(:thread_id AS text))::bigint)"),
{"thread_id": thread_id},
)
return await session.scalar(stmt)
return await session.scalar(stmt.with_for_update())
async def put(self, *, thread_id, run_id, event_type, category, content="", metadata=None, created_at=None): # noqa: D401
"""Write a single event — low-frequency path only.
@@ -100,10 +122,7 @@ class DbRunEventStore(RunEventStore):
user_id = self._user_id_from_context()
async with self._sf() as session:
async with session.begin():
# Use FOR UPDATE to serialize seq assignment within a thread.
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
# the UNIQUE(thread_id, seq) constraint catches races there.
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
max_seq = await self._max_seq_for_thread(session, thread_id)
seq = (max_seq or 0) + 1
row = RunEventRow(
thread_id=thread_id,
@@ -126,10 +145,8 @@ class DbRunEventStore(RunEventStore):
async with self._sf() as session:
async with session.begin():
# Get max seq for the thread (assume all events in batch belong to same thread).
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
# the UNIQUE(thread_id, seq) constraint catches races there.
thread_id = events[0]["thread_id"]
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
max_seq = await self._max_seq_for_thread(session, thread_id)
seq = max_seq or 0
rows = []
for e in events:
@@ -26,6 +26,28 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# Cache subagent token usage by tool_call_id so TokenUsageMiddleware can
# write it back to the triggering AIMessage's usage_metadata.
_subagent_usage_cache: dict[str, dict[str, int]] = {}
def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
if app_config is None:
try:
app_config = get_app_config()
except (FileNotFoundError, ValueError):
return False
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
def _cache_subagent_usage(tool_call_id: str, usage: dict | None, *, enabled: bool = True) -> None:
if enabled and usage:
_subagent_usage_cache[tool_call_id] = usage
def pop_cached_subagent_usage(tool_call_id: str) -> dict | None:
return _subagent_usage_cache.pop(tool_call_id, None)
def _is_subagent_terminal(result: Any) -> bool:
"""Return whether a background subagent result is safe to clean up."""
@@ -92,6 +114,17 @@ def _find_usage_recorder(runtime: Any) -> Any | None:
return None
def _summarize_usage(records: list[dict] | None) -> dict | None:
"""Summarize token usage records into a compact dict for SSE events."""
if not records:
return None
return {
"input_tokens": sum(r.get("input_tokens", 0) or 0 for r in records),
"output_tokens": sum(r.get("output_tokens", 0) or 0 for r in records),
"total_tokens": sum(r.get("total_tokens", 0) or 0 for r in records),
}
def _report_subagent_usage(runtime: Any, result: Any) -> None:
"""Report subagent token usage to the parent RunJournal, if available.
@@ -177,6 +210,7 @@ async def task_tool(
subagent_type: The type of subagent to use. ALWAYS PROVIDE THIS PARAMETER THIRD.
"""
runtime_app_config = _get_runtime_app_config(runtime)
cache_token_usage = _token_usage_cache_enabled(runtime_app_config)
available_subagent_names = get_available_subagent_names(app_config=runtime_app_config) if runtime_app_config is not None else get_available_subagent_names()
# Get subagent configuration
@@ -312,27 +346,32 @@ async def task_tool(
last_message_count = current_message_count
# Check if task completed, failed, or timed out
usage = _summarize_usage(getattr(result, "token_usage_records", None))
if result.status == SubagentStatus.COMPLETED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_completed", "task_id": task_id, "result": result.result})
writer({"type": "task_completed", "task_id": task_id, "result": result.result, "usage": usage})
logger.info(f"[trace={trace_id}] Task {task_id} completed after {poll_count} polls")
cleanup_background_task(task_id)
return f"Task Succeeded. Result: {result.result}"
elif result.status == SubagentStatus.FAILED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_failed", "task_id": task_id, "error": result.error})
writer({"type": "task_failed", "task_id": task_id, "error": result.error, "usage": usage})
logger.error(f"[trace={trace_id}] Task {task_id} failed: {result.error}")
cleanup_background_task(task_id)
return f"Task failed. Error: {result.error}"
elif result.status == SubagentStatus.CANCELLED:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error})
writer({"type": "task_cancelled", "task_id": task_id, "error": result.error, "usage": usage})
logger.info(f"[trace={trace_id}] Task {task_id} cancelled: {result.error}")
cleanup_background_task(task_id)
return "Task cancelled by user."
elif result.status == SubagentStatus.TIMED_OUT:
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
_report_subagent_usage(runtime, result)
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error})
writer({"type": "task_timed_out", "task_id": task_id, "error": result.error, "usage": usage})
logger.warning(f"[trace={trace_id}] Task {task_id} timed out: {result.error}")
cleanup_background_task(task_id)
return f"Task timed out. Error: {result.error}"
@@ -351,7 +390,9 @@ async def task_tool(
timeout_minutes = config.timeout_seconds // 60
logger.error(f"[trace={trace_id}] Task {task_id} polling timed out after {poll_count} polls (should have been caught by thread pool timeout)")
_report_subagent_usage(runtime, result)
writer({"type": "task_timed_out", "task_id": task_id})
usage = _summarize_usage(getattr(result, "token_usage_records", None))
_cache_subagent_usage(tool_call_id, usage, enabled=cache_token_usage)
writer({"type": "task_timed_out", "task_id": task_id, "usage": usage})
return f"Task polling timed out after {timeout_minutes} minutes. This may indicate the background task is stuck. Status: {result.status.value}"
except asyncio.CancelledError:
# Signal the background subagent thread to stop cooperatively.
@@ -374,4 +415,8 @@ async def task_tool(
cleanup_background_task(task_id)
else:
_schedule_deferred_subagent_cleanup(task_id, trace_id, max_poll_count)
_subagent_usage_cache.pop(tool_call_id, None)
raise
except Exception:
_subagent_usage_cache.pop(tool_call_id, None)
raise
@@ -7,7 +7,7 @@ from deerflow.config.app_config import AppConfig
from deerflow.reflection import resolve_variable
from deerflow.sandbox.security import is_host_bash_allowed
from deerflow.tools.builtins import ask_clarification_tool, present_file_tool, task_tool, view_image_tool
from deerflow.tools.builtins.tool_search import reset_deferred_registry
from deerflow.tools.builtins.tool_search import get_deferred_registry
from deerflow.tools.sync import make_sync_tool_wrapper
logger = logging.getLogger(__name__)
@@ -116,8 +116,6 @@ def get_available_tools(
# made through the Gateway API (which runs in a separate process) are immediately
# reflected when loading MCP tools.
mcp_tools = []
# Reset deferred registry upfront to prevent stale state from previous calls
reset_deferred_registry()
if include_mcp:
try:
from deerflow.config.extensions_config import ExtensionsConfig
@@ -135,12 +133,51 @@ def get_available_tools(
from deerflow.tools.builtins.tool_search import DeferredToolRegistry, set_deferred_registry
from deerflow.tools.builtins.tool_search import tool_search as tool_search_tool
registry = DeferredToolRegistry()
for t in mcp_tools:
registry.register(t)
set_deferred_registry(registry)
# Reuse the existing registry if one is already set for
# this async context. ``get_available_tools`` is
# re-entered whenever a subagent is spawned
# (``task_tool`` calls it to build the child agent's
# toolset), and previously we used to unconditionally
# rebuild the registry — wiping out the parent agent's
# tool_search promotions. The
# ``DeferredToolFilterMiddleware`` then re-hid those
# tools from subsequent model calls, leaving the agent
# able to see a tool's name but unable to invoke it
# (issue #2884). ``contextvars`` already gives us the
# lifetime semantics we want: a fresh request / graph
# run starts in a new asyncio task with the
# ContextVar at its default of ``None``, so reuse is
# only triggered for re-entrant calls inside one run.
#
# Intentionally NOT reconciling against the current
# ``mcp_tools`` snapshot. The MCP cache only refreshes
# on ``extensions_config.json`` mtime changes, which
# in practice happens between graph runs — not inside
# one. And even if a refresh did happen mid-run, the
# already-built lead agent's ``ToolNode`` still holds
# the *previous* tool set (LangGraph binds tools at
# graph construction time), so a brand-new MCP tool
# couldn't actually be invoked anyway. The
# ``DeferredToolRegistry`` doesn't retain the names
# of previously-promoted tools (``promote()`` drops
# the entry entirely), so re-syncing the registry
# against a fresh ``mcp_tools`` list would
# mis-classify those promotions as new tools and
# re-register them as deferred — exactly the bug
# this fix exists to prevent.
existing_registry = get_deferred_registry()
if existing_registry is None:
registry = DeferredToolRegistry()
for t in mcp_tools:
registry.register(t)
set_deferred_registry(registry)
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
else:
mcp_tool_names = {t.name for t in mcp_tools}
still_deferred = len(existing_registry)
promoted_count = max(0, len(mcp_tool_names) - still_deferred)
logger.info(f"Tool search active (preserved promotions): {still_deferred} tools deferred, {promoted_count} already promoted")
builtin_tools.append(tool_search_tool)
logger.info(f"Tool search active: {len(mcp_tools)} tools deferred")
except ImportError:
logger.warning("MCP module not available. Install 'langchain-mcp-adapters' package to enable MCP tools.")
except Exception as e:
+35
View File
@@ -0,0 +1,35 @@
[project]
name = "deerflow-storage"
version = "0.1.0"
description = "DeerFlow storage framework"
requires-python = ">=3.12"
dependencies = [
"dotenv>=0.9.9",
"pydantic>=2.12.5",
"pyyaml>=6.0.3",
"sqlalchemy[asyncio]>=2.0,<3.0",
"alembic>=1.13",
"langgraph>=1.1.9",
]
[project.optional-dependencies]
postgres = [
"asyncpg>=0.29",
"langgraph-checkpoint-postgres>=3.0.5",
"psycopg[binary]>=3.3.3",
"psycopg-pool>=3.3.0",
]
mysql = [
"aiomysql>=0.2",
"langgraph-checkpoint-mysql>=3.0.0",
]
sqlite = [
"aiosqlite>=0.22.1",
"langgraph-checkpoint-sqlite>=3.0.3"
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["store"]
@@ -0,0 +1,5 @@
from .enums import DataBaseType
__all__ = [
"DataBaseType",
]
@@ -0,0 +1,41 @@
from enum import Enum
from enum import IntEnum as SourceIntEnum
from enum import StrEnum as SourceStrEnum
from typing import Any, TypeVar
T = TypeVar("T", bound=Enum)
class _EnumBase:
"""Base enum class with common utility methods."""
@classmethod
def get_member_keys(cls) -> list[str]:
"""Return a list of enum member names."""
return list(cls.__members__.keys())
@classmethod
def get_member_values(cls) -> list:
"""Return a list of enum member values."""
return [item.value for item in cls.__members__.values()]
@classmethod
def get_member_dict(cls) -> dict[str, Any]:
"""Return a dict mapping member names to values."""
return {name: item.value for name, item in cls.__members__.items()}
class IntEnum(_EnumBase, SourceIntEnum):
"""Integer enum base class."""
class StrEnum(_EnumBase, SourceStrEnum):
"""String enum base class."""
class DataBaseType(StrEnum):
"""Database type."""
sqlite = "sqlite"
mysql = "mysql"
postgresql = "postgresql"
@@ -0,0 +1,286 @@
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field
from store.config.storage_config import StorageConfig
load_dotenv()
logger = logging.getLogger(__name__)
def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
cwd = Path.cwd().resolve()
candidates = (
cwd / "config.yaml",
backend_dir / "config.yaml",
repo_root / "config.yaml",
)
return tuple(dict.fromkeys(candidates))
def _storage_from_database_config(config_data: dict[str, Any]) -> None:
"""Keep the existing public `database:` config compatible with storage."""
if "storage" in config_data:
return
database = config_data.get("database")
if not isinstance(database, dict):
return
backend = database.get("backend")
if backend == "memory":
raise ValueError("database.backend='memory' is not supported by storage; handle memory mode before loading storage config")
storage: dict[str, Any] = {
"driver": "postgres" if backend == "postgres" else backend,
"sqlite_dir": database.get("sqlite_dir", ".deer-flow/data"),
"echo_sql": database.get("echo_sql", False),
"pool_size": database.get("pool_size", 5),
}
postgres_url = database.get("postgres_url")
if backend == "postgres" and isinstance(postgres_url, str) and postgres_url:
from sqlalchemy.engine.url import make_url
parsed = make_url(postgres_url)
storage["database_url"] = postgres_url
storage.update(
{
"username": parsed.username or "",
"password": parsed.password or "",
"host": parsed.host or "localhost",
"port": parsed.port or 5432,
"db_name": parsed.database or "deerflow",
}
)
config_data["storage"] = storage
class AppConfig(BaseModel):
"""DeerFlow application configuration."""
timezone: str = Field(default="UTC", description="Timezone for scheduling and timestamps (e.g. 'UTC', 'America/New_York')")
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
storage: StorageConfig = Field(default=StorageConfig())
model_config = ConfigDict(extra="allow", frozen=False)
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path:
"""Resolve the config file path.
Priority:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
"""
if config_path:
path = Path(config_path)
if not Path.exists(path):
raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
return path
elif os.getenv("DEER_FLOW_CONFIG_PATH"):
path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
if not Path.exists(path):
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
return path
else:
for path in _default_config_candidates():
if path.exists():
return path
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
@classmethod
def from_file(cls, config_path: str | None = None) -> Self:
"""Load and validate config from YAML. See `resolve_config_path` for path resolution."""
resolved_path = cls.resolve_config_path(config_path)
with open(resolved_path, encoding="utf-8") as f:
config_data = yaml.safe_load(f) or {}
cls._check_config_version(config_data, resolved_path)
config_data = cls.resolve_env_variables(config_data)
_storage_from_database_config(config_data)
if os.getenv("TIMEZONE"):
config_data["timezone"] = os.getenv("TIMEZONE")
result = cls.model_validate(config_data)
return result
@classmethod
def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
"""Check if the user's config.yaml is outdated compared to config.example.yaml.
Emits a warning if the user's config_version is lower than the example's.
Missing config_version is treated as version 0 (pre-versioning).
"""
try:
user_version = int(config_data.get("config_version", 0))
except (TypeError, ValueError):
user_version = 0
# Find config.example.yaml by searching config.yaml's directory and its parents
example_path = None
search_dir = config_path.parent
for _ in range(5): # search up to 5 levels
candidate = search_dir / "config.example.yaml"
if candidate.exists():
example_path = candidate
break
parent = search_dir.parent
if parent == search_dir:
break
search_dir = parent
if example_path is None:
return
try:
with open(example_path, encoding="utf-8") as f:
example_data = yaml.safe_load(f)
raw = example_data.get("config_version", 0) if example_data else 0
try:
example_version = int(raw)
except (TypeError, ValueError):
example_version = 0
except Exception:
return
if user_version < example_version:
logger.warning(
"Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to merge new fields into your config.",
user_version,
example_version,
)
@classmethod
def resolve_env_variables(cls, config: Any) -> Any:
"""Recursively replace $VAR strings with their environment variable values (e.g. $OPENAI_API_KEY)."""
if isinstance(config, str):
if config.startswith("$"):
env_value = os.getenv(config[1:])
if env_value is None:
raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
return env_value
return config
elif isinstance(config, dict):
return {k: cls.resolve_env_variables(v) for k, v in config.items()}
elif isinstance(config, list):
return [cls.resolve_env_variables(item) for item in config]
return config
_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
def _get_config_mtime(config_path: Path) -> float | None:
"""Get the modification time of a config file if it exists."""
try:
return config_path.stat().st_mtime
except OSError:
return None
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
"""Load config from disk and refresh cache metadata."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
resolved_path = AppConfig.resolve_config_path(config_path)
_app_config = AppConfig.from_file(str(resolved_path))
_app_config_path = resolved_path
_app_config_mtime = _get_config_mtime(resolved_path)
_app_config_is_custom = False
return _app_config
def get_app_config() -> AppConfig:
"""Get the DeerFlow config instance.
Returns a cached singleton instance and automatically reloads it when the
underlying config file path or modification time changes. Use
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
the cache.
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
resolved_path = AppConfig.resolve_config_path()
current_mtime = _get_config_mtime(resolved_path)
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
if should_reload:
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
logger.info(
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
_app_config_mtime,
current_mtime,
)
_load_and_cache_app_config(str(resolved_path))
return _app_config
def reload_app_config(config_path: str | None = None) -> AppConfig:
"""Force reload from file and update the cache."""
return _load_and_cache_app_config(config_path)
def reset_app_config() -> None:
"""Clear the cache so the next `get_app_config()` reloads from file."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = None
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = False
def set_app_config(config: AppConfig) -> None:
"""Inject a config instance directly, bypassing file loading (for testing)."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = config
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
@@ -0,0 +1,69 @@
"""Unified storage backend configuration for checkpointer and application data.
SQLite: checkpointer {sqlite_dir}/checkpoints.db, app {sqlite_dir}/deerflow.db
(separate files to avoid write-lock contention)
Postgres: shared URL, independent connection pools per layer.
Sensitive values use $VAR syntax resolved by AppConfig.resolve_env_variables()
before this config is instantiated.
"""
from __future__ import annotations
import os
from typing import Literal
from pydantic import BaseModel, Field
def _strip_legacy_state_prefix(path: str) -> str:
"""Keep old .deer-flow/* config values compatible with Paths.base_dir."""
prefix = ".deer-flow/"
if path == ".deer-flow":
return "."
if path.startswith(prefix):
return path[len(prefix) :]
return path
class StorageConfig(BaseModel):
driver: Literal["mysql", "sqlite", "postgres", "postgresql"] = Field(
default="sqlite",
description="Storage driver for both checkpointer and application data. 'sqlite' for single-node deployment (default),'postgres' for production multi-node deployment, 'mysql' for MySQL databases.",
)
sqlite_dir: str = Field(
default=".deer-flow/data",
description="Directory for SQLite .db files (sqlite driver only).",
)
username: str = Field(default="", description="db username ")
password: str = Field(default="", description="db password. Use $VAR syntax in config.yaml to read from .env.")
host: str = Field(default="localhost", description="db host.")
port: int = Field(default=5432, description="db port.")
db_name: str = Field(default="deerflow", description="db database name.")
database_url: str = Field(default="", description="Complete SQLAlchemy database URL. Takes precedence for non-SQLite drivers.")
sqlite_db_path: str = Field(default=".deer-flow/data", description="Directory for SQLite .db files (sqlite driver only).")
echo_sql: bool = Field(default=False, description="Log all SQL statements (debug only).")
pool_size: int = Field(default=5, description="Connection pool size per layer.")
# -- Derived helpers (not user-configured) --
@property
def _resolved_sqlite_dir(self) -> str:
"""Resolve sqlite_dir to an absolute path under DeerFlow's base dir."""
from pathlib import Path
path = Path(self.sqlite_dir)
if path.is_absolute():
return str(path.resolve())
try:
from deerflow.config.paths import resolve_path
return str(resolve_path(_strip_legacy_state_prefix(self.sqlite_dir)))
except ImportError:
return str(path.resolve())
@property
def sqlite_storage_path(self) -> str:
"""SQLite file path for storage-owned app data and checkpointer."""
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
@@ -0,0 +1,32 @@
from store.persistence.base_model import (
Base,
DataClassBase,
DateTimeMixin,
MappedBase,
TimeZone,
UniversalText,
id_key,
)
from .factory import (
create_persistence,
create_persistence_from_database_config,
create_persistence_from_storage_config,
storage_config_from_database_config,
)
from .types import AppPersistence
__all__ = [
"Base",
"DataClassBase",
"DateTimeMixin",
"MappedBase",
"TimeZone",
"UniversalText",
"id_key",
"create_persistence",
"create_persistence_from_database_config",
"create_persistence_from_storage_config",
"storage_config_from_database_config",
"AppPersistence",
]
@@ -0,0 +1,111 @@
from datetime import datetime
from typing import Annotated
from sqlalchemy import BigInteger, DateTime, Integer, Text, TypeDecorator
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
from store.utils import get_timezone
def current_time() -> datetime:
return get_timezone().now()
id_key = Annotated[
int,
mapped_column(
BigInteger().with_variant(Integer, "sqlite"),
primary_key=True,
unique=True,
index=True,
autoincrement=True,
sort_order=-999,
comment="Primary key ID",
),
]
class UniversalText(TypeDecorator[str]):
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
impl = Text
cache_ok = True
def load_dialect_impl(self, dialect): # noqa: ANN001
if dialect.name == "mysql":
return dialect.type_descriptor(LONGTEXT())
return dialect.type_descriptor(Text())
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
return value
def process_result_value(self, value: str | None, dialect) -> str | None: # noqa: ANN001
return value
class TimeZone(TypeDecorator[datetime]):
"""Timezone-aware datetime type compatible with PostgreSQL and MySQL."""
impl = DateTime(timezone=True)
cache_ok = True
@property
def python_type(self) -> type[datetime]:
return datetime
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
timezone = get_timezone()
if value is not None and value.utcoffset() != timezone.now().utcoffset():
value = timezone.from_datetime(value)
return value
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
timezone = get_timezone()
if value is not None and value.tzinfo is None:
value = value.replace(tzinfo=timezone.tz_info)
return value
class DateTimeMixin(MappedAsDataclass):
"""Mixin that adds created_time / updated_time columns."""
created_time: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
TimeZone,
init=False,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
class MappedBase(AsyncAttrs, DeclarativeBase):
"""Async-capable declarative base for all ORM models."""
@declared_attr.directive
def __tablename__(self) -> str:
return self.__name__.lower()
@declared_attr.directive
def __table_args__(self) -> dict:
return {"comment": self.__doc__ or ""}
class DataClassBase(MappedAsDataclass, MappedBase):
"""Declarative base with native dataclass integration."""
__abstract__ = True
class Base(DataClassBase, DateTimeMixin):
"""Declarative dataclass base with created_time / updated_time columns."""
__abstract__ = True
@@ -0,0 +1,9 @@
from .mysql import build_mysql_persistence
from .postgres import build_postgres_persistence
from .sqlite import build_sqlite_persistence
__all__ = [
"build_postgres_persistence",
"build_mysql_persistence",
"build_sqlite_persistence",
]
@@ -0,0 +1,76 @@
from __future__ import annotations
import json
from sqlalchemy import URL
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence
def _validate_mysql_driver(db_url: URL) -> str:
url = make_url(db_url)
driver = url.get_driver_name()
if driver not in {"aiomysql", "asyncmy"}:
raise ValueError(f"MySQL persistence requires async SQLAlchemy driver (aiomysql/asyncmy), got: {driver!r}")
return driver
def _checkpoint_conn_string(db_url: URL) -> str:
return db_url.render_as_string(hide_password=False)
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
_validate_mysql_driver(db_url)
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver
import store.repositories.models # noqa: F401
engine = create_async_engine(
db_url,
echo=echo,
future=True,
pool_pre_ping=True,
pool_size=pool_size,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(db_url))
checkpointer = await saver_cm.__aenter__()
async def setup() -> None:
# 1. LangGraph checkpoint tables / migrations
await checkpointer.setup()
# 2. ORM business tables
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
async def _close_saver() -> None:
await saver_cm.__aexit__(None, None, None)
async def aclose() -> None:
await close_in_order(
engine.dispose,
_close_saver,
)
return AppPersistence(
checkpointer=checkpointer,
engine=engine,
session_factory=session_factory,
setup=setup,
aclose=aclose,
)
@@ -0,0 +1,64 @@
from __future__ import annotations
import json
from sqlalchemy import URL
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence
def _checkpoint_conn_string(db_url: URL) -> str:
return db_url.set(drivername="postgresql").render_as_string(hide_password=False)
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
import store.repositories.models # noqa: F401
engine = create_async_engine(
db_url,
echo=echo,
future=True,
pool_pre_ping=True,
pool_size=pool_size,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
saver_cm = AsyncPostgresSaver.from_conn_string(_checkpoint_conn_string(db_url))
checkpointer = await saver_cm.__aenter__()
async def setup() -> None:
# 1. LangGraph checkpoint tables / migrations
await checkpointer.setup()
# 2. ORM business tables
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
async def _close_saver() -> None:
await saver_cm.__aexit__(None, None, None)
async def aclose() -> None:
await close_in_order(
engine.dispose,
_close_saver,
)
return AppPersistence(
checkpointer=checkpointer,
engine=engine,
session_factory=session_factory,
setup=setup,
aclose=aclose,
)
@@ -0,0 +1,68 @@
from __future__ import annotations
import json
from sqlalchemy import URL, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence
async def build_sqlite_persistence(db_url: URL, *, echo: bool = False) -> AppPersistence:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
import store.repositories.models # noqa: F401
engine = create_async_engine(
db_url,
echo=echo,
future=True,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
@event.listens_for(engine.sync_engine, "connect")
def _enable_sqlite_pragmas(dbapi_conn, _record): # noqa: ANN001
cursor = dbapi_conn.cursor()
try:
cursor.execute("PRAGMA journal_mode=WAL;")
cursor.execute("PRAGMA synchronous=NORMAL;")
cursor.execute("PRAGMA foreign_keys=ON;")
finally:
cursor.close()
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
saver_cm = AsyncSqliteSaver.from_conn_string(db_url.database)
checkpointer = await saver_cm.__aenter__()
async def setup() -> None:
# 1. LangGraph checkpoint tables
await checkpointer.setup()
# 2. ORM business tables
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
async def _close_saver() -> None:
await saver_cm.__aexit__(None, None, None)
async def aclose() -> None:
await close_in_order(
engine.dispose,
_close_saver,
)
return AppPersistence(
checkpointer=checkpointer,
engine=engine,
session_factory=session_factory,
setup=setup,
aclose=aclose,
)
@@ -0,0 +1,123 @@
from typing import Any
from sqlalchemy import URL
from sqlalchemy.engine.url import make_url
from store.common import DataBaseType
from store.config.app_config import get_app_config
from store.config.storage_config import StorageConfig
from store.persistence.types import AppPersistence
def storage_config_from_database_config(database_config: Any) -> StorageConfig:
"""Convert the existing public DatabaseConfig shape to StorageConfig.
Storage only owns durable database-backed persistence. The app bridge
should handle memory mode before calling into this package.
"""
backend = getattr(database_config, "backend", None)
if backend == "sqlite":
return StorageConfig(
driver="sqlite",
sqlite_dir=getattr(database_config, "sqlite_dir", ".deer-flow/data"),
echo_sql=getattr(database_config, "echo_sql", False),
pool_size=getattr(database_config, "pool_size", 5),
)
if backend == "postgres":
postgres_url = getattr(database_config, "postgres_url", "")
if not postgres_url:
raise ValueError("database.postgres_url is required when database.backend is 'postgres'")
parsed = make_url(postgres_url)
return StorageConfig(
driver="postgres",
database_url=postgres_url,
username=parsed.username or "",
password=parsed.password or "",
host=parsed.host or "localhost",
port=parsed.port or 5432,
db_name=parsed.database or "deerflow",
echo_sql=getattr(database_config, "echo_sql", False),
pool_size=getattr(database_config, "pool_size", 5),
)
raise ValueError(f"Unsupported database backend for storage persistence: {backend!r}")
def _create_database_url(storage_config: StorageConfig) -> URL:
"""Build an async SQLAlchemy URL from StorageConfig (sqlite/mysql/postgres)."""
if storage_config.driver == DataBaseType.sqlite:
driver = "sqlite+aiosqlite"
elif storage_config.driver == DataBaseType.mysql:
driver = "mysql+aiomysql"
elif storage_config.driver in (DataBaseType.postgresql, "postgres"):
driver = "postgresql+asyncpg"
else:
raise ValueError(f"Unsupported database driver: {storage_config.driver}")
if storage_config.driver == DataBaseType.sqlite:
import os
db_path = storage_config.sqlite_storage_path
os.makedirs(os.path.dirname(db_path), exist_ok=True)
url = URL.create(
drivername=driver,
database=db_path,
)
elif storage_config.database_url:
url = make_url(storage_config.database_url)
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
url = url.set(drivername="postgresql+asyncpg")
elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
url = url.set(drivername="mysql+aiomysql")
else:
url = URL.create(
drivername=driver,
username=storage_config.username,
password=storage_config.password,
host=storage_config.host,
port=storage_config.port,
database=storage_config.db_name or "deerflow",
)
return url
async def create_persistence_from_storage_config(storage_config: StorageConfig) -> AppPersistence:
from .drivers.mysql import build_mysql_persistence
from .drivers.postgres import build_postgres_persistence
from .drivers.sqlite import build_sqlite_persistence
driver = storage_config.driver
db_url = _create_database_url(storage_config)
if driver in ("postgres", "postgresql"):
return await build_postgres_persistence(
db_url,
echo=storage_config.echo_sql,
pool_size=storage_config.pool_size,
)
if driver == "mysql":
return await build_mysql_persistence(
db_url,
echo=storage_config.echo_sql,
pool_size=storage_config.pool_size,
)
if driver == "sqlite":
return await build_sqlite_persistence(db_url, echo=storage_config.echo_sql)
raise ValueError(f"Unsupported database driver: {driver}")
async def create_persistence_from_database_config(database_config: Any) -> AppPersistence:
storage_config = storage_config_from_database_config(database_config)
return await create_persistence_from_storage_config(storage_config)
async def create_persistence() -> AppPersistence:
app_config = get_app_config()
return await create_persistence_from_storage_config(app_config.storage)
@@ -0,0 +1,189 @@
"""Dialect-aware JSON value matching for storage SQLAlchemy repositories."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
from sqlalchemy import BigInteger, Float, String, bindparam
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.types import Boolean, TypeEngine
_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)
_INT64_MIN = -(2**63)
_INT64_MAX = 2**63 - 1
def validate_metadata_filter_key(key: object) -> bool:
"""Return True when *key* is safe for JSON metadata filter SQL paths."""
return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key))
def validate_metadata_filter_value(value: object) -> bool:
"""Return True when *value* can be compiled into a portable JSON predicate."""
if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
return False
if isinstance(value, int) and not isinstance(value, bool):
return _INT64_MIN <= value <= _INT64_MAX
return True
class JsonMatch(ColumnElement[bool]):
"""Dialect-portable ``column[key] == value`` for JSON columns."""
inherit_cache = True
type = Boolean()
_is_implicitly_boolean = True
_traverse_internals = [
("column", InternalTraversal.dp_clauseelement),
("key", InternalTraversal.dp_string),
("value", InternalTraversal.dp_plain_obj),
("value_type", InternalTraversal.dp_string),
]
def __init__(self, column: ColumnElement[Any], key: str, value: object) -> None:
if not validate_metadata_filter_key(key):
raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")
if not validate_metadata_filter_value(value):
if isinstance(value, int) and not isinstance(value, bool):
raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}")
raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}")
self.column = column
self.key = key
self.value = value
self.value_type = type(value).__qualname__
super().__init__()
@dataclass(frozen=True)
class _Dialect:
null_type: str
num_types: tuple[str, ...]
num_cast: str
int_types: tuple[str, ...]
int_cast: str
int_guard: str | None
string_type: str
bool_type: str | None
true_value: str
false_value: str
_SQLITE = _Dialect(
null_type="null",
num_types=("integer", "real"),
num_cast="REAL",
int_types=("integer",),
int_cast="INTEGER",
int_guard=None,
string_type="text",
bool_type=None,
true_value="true",
false_value="false",
)
_POSTGRES = _Dialect(
null_type="null",
num_types=("number",),
num_cast="DOUBLE PRECISION",
int_types=("number",),
int_cast="BIGINT",
int_guard="'^-?[0-9]+$'",
string_type="string",
bool_type="boolean",
true_value="true",
false_value="false",
)
_MYSQL = _Dialect(
null_type="NULL",
num_types=("INTEGER", "DOUBLE", "DECIMAL"),
num_cast="DOUBLE",
int_types=("INTEGER",),
int_cast="SIGNED",
int_guard=None,
string_type="STRING",
bool_type="BOOLEAN",
true_value="true",
false_value="false",
)
def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
param = bindparam(None, value, type_=sa_type)
return compiler.process(param, **kw)
def _type_check(typeof: str, types: tuple[str, ...]) -> str:
if len(types) == 1:
return f"{typeof} = '{types[0]}'"
quoted = ", ".join(f"'{type_name}'" for type_name in types)
return f"{typeof} IN ({quoted})"
def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
if value is None:
return f"{typeof} = '{dialect.null_type}'"
if isinstance(value, bool):
bool_str = dialect.true_value if value else dialect.false_value
if dialect.bool_type is None:
return f"{typeof} = '{bool_str}'"
return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')"
if isinstance(value, int):
bp = _bind(compiler, value, BigInteger(), **kw)
if dialect.int_guard:
return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})"
if isinstance(value, float):
bp = _bind(compiler, value, Float(), **kw)
return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})"
bp = _bind(compiler, str(value), String(), **kw)
return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})"
@compiles(JsonMatch, "sqlite")
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
path = f'$."{element.key}"'
typeof = f"json_type({col}, '{path}')"
extract = f"json_extract({col}, '{path}')"
return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw)
@compiles(JsonMatch, "postgresql")
def _compile_postgres(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
typeof = f"json_typeof({col} -> '{element.key}')"
extract = f"({col} ->> '{element.key}')"
return _build_clause(compiler, typeof, extract, element.value, _POSTGRES, **kw)
@compiles(JsonMatch, "mysql")
def _compile_mysql(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
path = f'$."{element.key}"'
typeof = f"JSON_TYPE(JSON_EXTRACT({col}, '{path}'))"
extract = f"JSON_UNQUOTE(JSON_EXTRACT({col}, '{path}'))"
return _build_clause(compiler, typeof, extract, element.value, _MYSQL, **kw)
@compiles(JsonMatch)
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
raise NotImplementedError(f"JsonMatch supports sqlite, postgresql, and mysql; got dialect: {compiler.dialect.name}")
def json_match(column: ColumnElement[Any], key: str, value: object) -> JsonMatch:
return JsonMatch(column, key, value)
@@ -0,0 +1,3 @@
from .close import close_in_order
__all__ = ["close_in_order"]
@@ -0,0 +1,28 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
AsyncCloser = Callable[[], Awaitable[None]]
async def close_in_order(*closers: AsyncCloser) -> None:
"""
Run async closers in order and raise the first error, if any.
Notes
-----
- Used to keep driver-specific close logic readable.
- We intentionally do not stop at first failure, so later resources
still get a chance to close.
"""
first_error: Exception | None = None
for closer in closers:
try:
await closer()
except Exception as exc:
if first_error is None:
first_error = exc
if first_error is not None:
raise first_error
@@ -0,0 +1,23 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
AsyncSetup = Callable[[], Awaitable[None]]
AsyncClose = Callable[[], Awaitable[None]]
@dataclass(slots=True)
class AppPersistence:
"""
Unified runtime persistence bundle.
"""
checkpointer: Checkpointer
engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession]
setup: AsyncSetup
aclose: AsyncClose
@@ -0,0 +1,53 @@
from store.repositories.contracts import (
Feedback,
FeedbackAggregate,
FeedbackCreate,
FeedbackRepositoryProtocol,
InvalidMetadataFilterError,
Run,
RunCreate,
RunEvent,
RunEventCreate,
RunEventRepositoryProtocol,
RunRepositoryProtocol,
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
User,
UserCreate,
UserNotFoundError,
UserRepositoryProtocol,
)
from store.repositories.factory import (
build_feedback_repository,
build_run_event_repository,
build_run_repository,
build_thread_meta_repository,
build_user_repository,
)
__all__ = [
"Feedback",
"FeedbackAggregate",
"FeedbackCreate",
"FeedbackRepositoryProtocol",
"InvalidMetadataFilterError",
"Run",
"RunCreate",
"RunEvent",
"RunEventCreate",
"RunEventRepositoryProtocol",
"RunRepositoryProtocol",
"ThreadMeta",
"ThreadMetaCreate",
"ThreadMetaRepositoryProtocol",
"User",
"UserCreate",
"UserNotFoundError",
"UserRepositoryProtocol",
"build_run_repository",
"build_run_event_repository",
"build_thread_meta_repository",
"build_feedback_repository",
"build_user_repository",
]
@@ -0,0 +1,49 @@
from store.repositories.contracts.feedback import (
Feedback,
FeedbackAggregate,
FeedbackCreate,
FeedbackRepositoryProtocol,
)
from store.repositories.contracts.run import (
Run,
RunCreate,
RunRepositoryProtocol,
)
from store.repositories.contracts.run_event import (
RunEvent,
RunEventCreate,
RunEventRepositoryProtocol,
)
from store.repositories.contracts.thread_meta import (
InvalidMetadataFilterError,
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
)
from store.repositories.contracts.user import (
User,
UserCreate,
UserNotFoundError,
UserRepositoryProtocol,
)
__all__ = [
"Feedback",
"FeedbackAggregate",
"FeedbackCreate",
"FeedbackRepositoryProtocol",
"Run",
"RunCreate",
"RunEvent",
"RunEventCreate",
"RunEventRepositoryProtocol",
"RunRepositoryProtocol",
"InvalidMetadataFilterError",
"ThreadMeta",
"ThreadMetaCreate",
"ThreadMetaRepositoryProtocol",
"User",
"UserCreate",
"UserNotFoundError",
"UserRepositoryProtocol",
]
@@ -0,0 +1,77 @@
from __future__ import annotations
from datetime import datetime
from typing import Protocol, TypedDict
from pydantic import BaseModel, ConfigDict
class FeedbackCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
feedback_id: str
run_id: str
thread_id: str
rating: int
user_id: str | None = None
message_id: str | None = None
comment: str | None = None
class Feedback(BaseModel):
model_config = ConfigDict(frozen=True)
feedback_id: str
run_id: str
thread_id: str
rating: int
user_id: str | None
message_id: str | None
comment: str | None
created_time: datetime
class FeedbackAggregate(TypedDict):
run_id: str
total: int
positive: int
negative: int
class FeedbackRepositoryProtocol(Protocol):
async def create_feedback(self, data: FeedbackCreate) -> Feedback:
pass
async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
pass
async def get_feedback(self, feedback_id: str) -> Feedback | None:
pass
async def list_feedback_by_run(
self,
run_id: str,
*,
thread_id: str | None = None,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
pass
async def list_feedback_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
pass
async def delete_feedback(self, feedback_id: str) -> bool:
pass
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
pass
async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
pass
@@ -0,0 +1,100 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Protocol
from pydantic import BaseModel, ConfigDict, Field
class RunCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
run_id: str
thread_id: str
assistant_id: str | None = None
user_id: str | None = None
status: str = "pending"
model_name: str | None = None
multitask_strategy: str = "reject"
error: str | None = None
follow_up_to_run_id: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
kwargs: dict[str, Any] = Field(default_factory=dict)
created_time: datetime | None = None
class Run(BaseModel):
model_config = ConfigDict(frozen=True)
run_id: str
thread_id: str
assistant_id: str | None
user_id: str | None
status: str
model_name: str | None
multitask_strategy: str
error: str | None
follow_up_to_run_id: str | None
metadata: dict[str, Any]
kwargs: dict[str, Any]
total_input_tokens: int
total_output_tokens: int
total_tokens: int
llm_call_count: int
lead_agent_tokens: int
subagent_tokens: int
middleware_tokens: int
message_count: int
first_human_message: str | None
last_ai_message: str | None
created_time: datetime
updated_time: datetime | None
class RunRepositoryProtocol(Protocol):
async def create_run(self, data: RunCreate) -> Run:
pass
async def get_run(self, run_id: str) -> Run | None:
pass
async def list_runs_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[Run]:
pass
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
pass
async def delete_run(self, run_id: str) -> None:
pass
async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
pass
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
first_human_message: str | None = None,
last_ai_message: str | None = None,
error: str | None = None,
) -> None:
pass
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
pass
@@ -0,0 +1,83 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Protocol
from pydantic import BaseModel, ConfigDict, Field
class RunEventCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
thread_id: str
run_id: str
user_id: str | None = None
event_type: str
category: str
content: Any = ""
metadata: dict[str, Any] = Field(default_factory=dict)
created_at: datetime | None = None
class RunEvent(BaseModel):
model_config = ConfigDict(frozen=True)
thread_id: str
run_id: str
user_id: str | None
event_type: str
category: str
content: Any
metadata: dict[str, Any]
seq: int
created_at: datetime
class RunEventRepositoryProtocol(Protocol):
# Sequence values are time-ordered integer cursors. The application layer
# owns the single-writer invariant for a thread while a run is active.
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
pass
async def list_messages(
self,
thread_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
pass
async def list_events(
self,
thread_id: str,
run_id: str,
*,
event_types: list[str] | None = None,
limit: int = 500,
user_id: str | None = None,
) -> list[RunEvent]:
pass
async def list_messages_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
pass
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
pass
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
pass
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
pass
@@ -0,0 +1,67 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Protocol
from pydantic import BaseModel, ConfigDict, Field
class InvalidMetadataFilterError(ValueError):
"""Raised when all client-supplied metadata filters are rejected."""
class ThreadMetaCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
thread_id: str
assistant_id: str | None = None
user_id: str | None = None
display_name: str | None = None
status: str = "idle"
metadata: dict[str, Any] = Field(default_factory=dict)
class ThreadMeta(BaseModel):
model_config = ConfigDict(frozen=True)
thread_id: str
assistant_id: str | None
user_id: str | None
display_name: str | None
status: str
metadata: dict[str, Any]
created_time: datetime
updated_time: datetime | None
class ThreadMetaRepositoryProtocol(Protocol):
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
pass
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
pass
async def update_thread_meta(
self,
thread_id: str,
*,
display_name: str | None = None,
status: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
pass
async def delete_thread(self, thread_id: str) -> None:
pass
async def search_threads(
self,
*,
metadata: dict[str, Any] | None = None,
status: str | None = None,
user_id: str | None = None,
assistant_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ThreadMeta]:
pass
@@ -0,0 +1,64 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal, Protocol
from pydantic import BaseModel, ConfigDict
class UserNotFoundError(LookupError):
"""Raised when an update targets a user row that no longer exists."""
class UserCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
id: str
email: str
password_hash: str | None = None
system_role: Literal["admin", "user"] = "user"
created_at: datetime | None = None
oauth_provider: str | None = None
oauth_id: str | None = None
needs_setup: bool = False
token_version: int = 0
class User(BaseModel):
model_config = ConfigDict(frozen=True)
id: str
email: str
password_hash: str | None
system_role: Literal["admin", "user"]
created_at: datetime
oauth_provider: str | None
oauth_id: str | None
needs_setup: bool
token_version: int
class UserRepositoryProtocol(Protocol):
async def create_user(self, data: UserCreate) -> User:
pass
async def get_user_by_id(self, user_id: str) -> User | None:
pass
async def get_user_by_email(self, email: str) -> User | None:
pass
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
pass
async def get_first_admin(self) -> User | None:
pass
async def update_user(self, data: User) -> User:
pass
async def count_users(self) -> int:
pass
async def count_admin_users(self) -> int:
pass
@@ -0,0 +1,13 @@
from store.repositories.db.feedback import DbFeedbackRepository
from store.repositories.db.run import DbRunRepository
from store.repositories.db.run_event import DbRunEventRepository
from store.repositories.db.thread_meta import DbThreadMetaRepository
from store.repositories.db.user import DbUserRepository
__all__ = [
"DbFeedbackRepository",
"DbRunRepository",
"DbRunEventRepository",
"DbThreadMetaRepository",
"DbUserRepository",
]
@@ -0,0 +1,142 @@
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import case, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.feedback import Feedback, FeedbackAggregate, FeedbackCreate, FeedbackRepositoryProtocol
from store.repositories.models.feedback import Feedback as FeedbackModel
def _to_feedback(m: FeedbackModel) -> Feedback:
return Feedback(
feedback_id=m.feedback_id,
run_id=m.run_id,
thread_id=m.thread_id,
rating=m.rating,
user_id=m.user_id,
message_id=m.message_id,
comment=m.comment,
created_time=m.created_time,
)
class DbFeedbackRepository(FeedbackRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_feedback(self, data: FeedbackCreate) -> Feedback:
if data.rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {data.rating}")
model = FeedbackModel(
feedback_id=data.feedback_id,
run_id=data.run_id,
thread_id=data.thread_id,
rating=data.rating,
user_id=data.user_id,
message_id=data.message_id,
comment=data.comment,
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return _to_feedback(model)
async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
if data.rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {data.rating}")
result = await self._session.execute(
select(FeedbackModel).where(
FeedbackModel.thread_id == data.thread_id,
FeedbackModel.run_id == data.run_id,
FeedbackModel.user_id == data.user_id,
)
)
model = result.scalar_one_or_none()
if model is None:
return await self.create_feedback(data)
model.rating = data.rating
model.message_id = data.message_id
model.comment = data.comment
model.created_time = datetime.now(UTC)
await self._session.flush()
await self._session.refresh(model)
return _to_feedback(model)
async def get_feedback(self, feedback_id: str) -> Feedback | None:
result = await self._session.execute(select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
model = result.scalar_one_or_none()
return _to_feedback(model) if model else None
async def list_feedback_by_run(
self,
run_id: str,
*,
thread_id: str | None = None,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
stmt = select(FeedbackModel).where(FeedbackModel.run_id == run_id)
if thread_id is not None:
stmt = stmt.where(FeedbackModel.thread_id == thread_id)
if user_id is not None:
stmt = stmt.where(FeedbackModel.user_id == user_id)
stmt = stmt.order_by(FeedbackModel.created_time.desc())
if limit is not None:
stmt = stmt.limit(limit)
result = await self._session.execute(stmt)
return [_to_feedback(m) for m in result.scalars().all()]
async def list_feedback_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
stmt = select(FeedbackModel).where(FeedbackModel.thread_id == thread_id)
if user_id is not None:
stmt = stmt.where(FeedbackModel.user_id == user_id)
stmt = stmt.order_by(FeedbackModel.created_time.desc())
if limit is not None:
stmt = stmt.limit(limit)
result = await self._session.execute(stmt)
return [_to_feedback(m) for m in result.scalars().all()]
async def delete_feedback(self, feedback_id: str) -> bool:
existing = await self.get_feedback(feedback_id)
if existing is None:
return False
await self._session.execute(delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
return True
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
stmt = select(FeedbackModel).where(
FeedbackModel.thread_id == thread_id,
FeedbackModel.run_id == run_id,
)
if user_id is not None:
stmt = stmt.where(FeedbackModel.user_id == user_id)
result = await self._session.execute(stmt)
model = result.scalar_one_or_none()
if model is None:
return False
await self._session.delete(model)
return True
async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
stmt = select(
func.count().label("total"),
func.coalesce(func.sum(case((FeedbackModel.rating == 1, 1), else_=0)), 0).label("positive"),
func.coalesce(func.sum(case((FeedbackModel.rating == -1, 1), else_=0)), 0).label("negative"),
).where(FeedbackModel.thread_id == thread_id, FeedbackModel.run_id == run_id)
row = (await self._session.execute(stmt)).one()
return {
"run_id": run_id,
"total": int(row.total),
"positive": int(row.positive),
"negative": int(row.negative),
}
@@ -0,0 +1,185 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.run import Run, RunCreate, RunRepositoryProtocol
from store.repositories.models.run import Run as RunModel
def _to_run(m: RunModel) -> Run:
return Run(
run_id=m.run_id,
thread_id=m.thread_id,
assistant_id=m.assistant_id,
user_id=m.user_id,
status=m.status,
model_name=m.model_name,
multitask_strategy=m.multitask_strategy,
error=m.error,
follow_up_to_run_id=m.follow_up_to_run_id,
metadata=dict(m.meta or {}),
kwargs=dict(m.kwargs or {}),
total_input_tokens=m.total_input_tokens,
total_output_tokens=m.total_output_tokens,
total_tokens=m.total_tokens,
llm_call_count=m.llm_call_count,
lead_agent_tokens=m.lead_agent_tokens,
subagent_tokens=m.subagent_tokens,
middleware_tokens=m.middleware_tokens,
message_count=m.message_count,
first_human_message=m.first_human_message,
last_ai_message=m.last_ai_message,
created_time=m.created_time,
updated_time=m.updated_time,
)
class DbRunRepository(RunRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_run(self, data: RunCreate) -> Run:
model = RunModel(
run_id=data.run_id,
thread_id=data.thread_id,
assistant_id=data.assistant_id,
user_id=data.user_id,
status=data.status,
model_name=data.model_name,
multitask_strategy=data.multitask_strategy,
error=data.error,
follow_up_to_run_id=data.follow_up_to_run_id,
meta=dict(data.metadata),
kwargs=dict(data.kwargs),
)
if data.created_time is not None:
model.created_time = data.created_time
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return _to_run(model)
async def get_run(self, run_id: str) -> Run | None:
result = await self._session.execute(select(RunModel).where(RunModel.run_id == run_id))
model = result.scalar_one_or_none()
return _to_run(model) if model else None
async def list_runs_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[Run]:
stmt = select(RunModel).where(RunModel.thread_id == thread_id)
if user_id is not None:
stmt = stmt.where(RunModel.user_id == user_id)
stmt = stmt.order_by(RunModel.created_time.desc()).limit(limit).offset(offset)
result = await self._session.execute(stmt)
return [_to_run(m) for m in result.scalars().all()]
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
values: dict = {"status": status}
if error is not None:
values["error"] = error
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
async def delete_run(self, run_id: str) -> None:
await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))
async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
if before is None:
before_dt = datetime.now().astimezone()
elif isinstance(before, datetime):
before_dt = before
else:
before_dt = datetime.fromisoformat(before)
result = await self._session.execute(select(RunModel).where(RunModel.status == "pending", RunModel.created_time <= before_dt).order_by(RunModel.created_time.asc()))
return [_to_run(m) for m in result.scalars().all()]
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
first_human_message: str | None = None,
last_ai_message: str | None = None,
error: str | None = None,
) -> None:
values = {
"status": status,
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
}
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if error is not None:
values["error"] = error
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
completed = RunModel.status.in_(("success", "error"))
model_expr = func.coalesce(RunModel.model_name, "unknown")
stmt = (
select(
model_expr.label("model"),
func.count().label("runs"),
func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
func.coalesce(func.sum(RunModel.total_output_tokens), 0).label("total_output_tokens"),
func.coalesce(func.sum(RunModel.lead_agent_tokens), 0).label("lead_agent"),
func.coalesce(func.sum(RunModel.subagent_tokens), 0).label("subagent"),
func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
)
.where(RunModel.thread_id == thread_id, completed)
.group_by(model_expr)
)
rows = (await self._session.execute(stmt)).all()
total_tokens = total_input = total_output = total_runs = 0
lead_agent = subagent = middleware = 0
by_model: dict[str, dict] = {}
for row in rows:
by_model[row.model] = {"tokens": row.total_tokens, "runs": row.runs}
total_tokens += row.total_tokens
total_input += row.total_input_tokens
total_output += row.total_output_tokens
total_runs += row.runs
lead_agent += row.lead_agent
subagent += row.subagent
middleware += row.middleware
return {
"total_tokens": total_tokens,
"total_input_tokens": total_input,
"total_output_tokens": total_output,
"total_runs": total_runs,
"by_model": by_model,
"by_caller": {
"lead_agent": lead_agent,
"subagent": subagent,
"middleware": middleware,
},
}
@@ -0,0 +1,207 @@
from __future__ import annotations
import json
import secrets
import threading
import time
from typing import Any
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
from store.repositories.models.run_event import RunEvent as RunEventModel
_SEQ_COUNTER_BITS = 12
_SEQ_PROCESS_BITS = 9
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
class _SequenceAllocator:
def __init__(self) -> None:
self._last_millis = 0
self._lock = threading.Lock()
def allocate_base(self, batch_size: int) -> int:
if batch_size >= _SEQ_COUNTER_LIMIT:
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
now_ms = time.time_ns() // 1_000_000
with self._lock:
seq_ms = max(now_ms, self._last_millis + 1)
self._last_millis = seq_ms
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
_sequence_allocator = _SequenceAllocator()
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
if not isinstance(content, str):
next_metadata = {**metadata, "content_is_json": True}
if isinstance(content, dict):
next_metadata["content_is_dict"] = True
return json.dumps(content, default=str, ensure_ascii=False), next_metadata
return content, metadata
def _deserialize_content(content: str, metadata: dict[str, Any]) -> Any:
if not (metadata.get("content_is_json") or metadata.get("content_is_dict")):
return content
try:
return json.loads(content)
except json.JSONDecodeError:
return content
def _to_run_event(model: RunEventModel) -> RunEvent:
raw_metadata = dict(model.meta or {})
metadata = {key: value for key, value in raw_metadata.items() if key != "content_is_dict"}
return RunEvent(
thread_id=model.thread_id,
run_id=model.run_id,
user_id=model.user_id,
event_type=model.event_type,
category=model.category,
content=_deserialize_content(model.content, raw_metadata),
metadata=metadata,
seq=model.seq,
created_at=model.created_at,
)
class DbRunEventRepository(RunEventRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
if not events:
return []
seq_base = _sequence_allocator.allocate_base(len(events))
rows: list[RunEventModel] = []
for index, event in enumerate(events, start=1):
content, metadata = _serialize_content(event.content, dict(event.metadata))
row = RunEventModel(
thread_id=event.thread_id,
run_id=event.run_id,
user_id=event.user_id,
seq=seq_base + index,
event_type=event.event_type,
category=event.category,
content=content,
meta=metadata,
)
if event.created_at is not None:
row.created_at = event.created_at
self._session.add(row)
rows.append(row)
await self._session.flush()
return [_to_run_event(row) for row in rows]
async def list_messages(
self,
thread_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
stmt = select(RunEventModel).where(
RunEventModel.thread_id == thread_id,
RunEventModel.category == "message",
)
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
if before_seq is not None:
stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
if after_seq is not None:
stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
result = await self._session.execute(stmt)
return [_to_run_event(row) for row in result.scalars().all()]
stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
async def list_events(
self,
thread_id: str,
run_id: str,
*,
event_types: list[str] | None = None,
limit: int = 500,
user_id: str | None = None,
) -> list[RunEvent]:
stmt = select(RunEventModel).where(
RunEventModel.thread_id == thread_id,
RunEventModel.run_id == run_id,
)
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
if event_types is not None:
stmt = stmt.where(RunEventModel.event_type.in_(event_types))
stmt = stmt.order_by(RunEventModel.seq.asc()).limit(limit)
result = await self._session.execute(stmt)
return [_to_run_event(row) for row in result.scalars().all()]
async def list_messages_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
stmt = select(RunEventModel).where(
RunEventModel.thread_id == thread_id,
RunEventModel.run_id == run_id,
RunEventModel.category == "message",
)
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
if before_seq is not None:
stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
if after_seq is not None:
stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
result = await self._session.execute(stmt)
return [_to_run_event(row) for row in result.scalars().all()]
stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
stmt = select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
count = await self._session.scalar(stmt)
return int(count or 0)
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
conditions = [RunEventModel.thread_id == thread_id]
if user_id is not None:
conditions.append(RunEventModel.user_id == user_id)
count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
await self._session.execute(delete(RunEventModel).where(*conditions))
return int(count or 0)
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
conditions = [RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id]
if user_id is not None:
conditions.append(RunEventModel.user_id == user_id)
count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
await self._session.execute(delete(RunEventModel).where(*conditions))
return int(count or 0)
@@ -0,0 +1,113 @@
from __future__ import annotations
import logging
from typing import Any
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from store.persistence.json_compat import json_match
from store.repositories.contracts.thread_meta import (
InvalidMetadataFilterError,
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
)
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
logger = logging.getLogger(__name__)
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
return ThreadMeta(
thread_id=m.thread_id,
assistant_id=m.assistant_id,
user_id=m.user_id,
display_name=m.display_name,
status=m.status,
metadata=dict(m.meta or {}),
created_time=m.created_time,
updated_time=m.updated_time,
)
class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
model = ThreadMetaModel(
thread_id=data.thread_id,
assistant_id=data.assistant_id,
user_id=data.user_id,
display_name=data.display_name,
status=data.status,
meta=dict(data.metadata),
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return _to_thread_meta(model)
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
result = await self._session.execute(select(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
model = result.scalar_one_or_none()
return _to_thread_meta(model) if model else None
async def update_thread_meta(
self,
thread_id: str,
*,
display_name: str | None = None,
status: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
values: dict = {}
if display_name is not None:
values["display_name"] = display_name
if status is not None:
values["status"] = status
if metadata is not None:
values["meta"] = dict(metadata)
if not values:
return
await self._session.execute(update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))
async def delete_thread(self, thread_id: str) -> None:
await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
async def search_threads(
self,
*,
metadata: dict[str, Any] | None = None,
status: str | None = None,
user_id: str | None = None,
assistant_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ThreadMeta]:
stmt = select(ThreadMetaModel)
if status is not None:
stmt = stmt.where(ThreadMetaModel.status == status)
if user_id is not None:
stmt = stmt.where(ThreadMetaModel.user_id == user_id)
if assistant_id is not None:
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
if metadata:
applied = 0
for key, value in metadata.items():
try:
stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
applied += 1
except (ValueError, TypeError) as exc:
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
if applied == 0:
rejected_keys = ", ".join(sorted(str(key) for key in metadata))
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
stmt = stmt.limit(limit).offset(offset)
result = await self._session.execute(stmt)
return [_to_thread_meta(m) for m in result.scalars().all()]
@@ -0,0 +1,98 @@
from __future__ import annotations
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.user import User, UserCreate, UserNotFoundError, UserRepositoryProtocol
from store.repositories.models.user import User as UserModel
def _to_user(model: UserModel) -> User:
return User(
id=model.id,
email=model.email,
password_hash=model.password_hash,
system_role=model.system_role, # type: ignore[arg-type]
created_at=model.created_at,
oauth_provider=model.oauth_provider,
oauth_id=model.oauth_id,
needs_setup=model.needs_setup,
token_version=model.token_version,
)
class DbUserRepository(UserRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_user(self, data: UserCreate) -> User:
model = UserModel(
id=data.id,
email=data.email,
system_role=data.system_role,
password_hash=data.password_hash,
oauth_provider=data.oauth_provider,
oauth_id=data.oauth_id,
needs_setup=data.needs_setup,
token_version=data.token_version,
)
if data.created_at is not None:
model.created_at = data.created_at
self._session.add(model)
try:
await self._session.flush()
except IntegrityError as exc:
await self._session.rollback()
raise ValueError(f"Email already registered: {data.email}") from exc
await self._session.refresh(model)
return _to_user(model)
async def get_user_by_id(self, user_id: str) -> User | None:
model = await self._session.get(UserModel, user_id)
return _to_user(model) if model is not None else None
async def get_user_by_email(self, email: str) -> User | None:
result = await self._session.execute(select(UserModel).where(UserModel.email == email))
model = result.scalar_one_or_none()
return _to_user(model) if model is not None else None
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
result = await self._session.execute(
select(UserModel).where(
UserModel.oauth_provider == provider,
UserModel.oauth_id == oauth_id,
)
)
model = result.scalar_one_or_none()
return _to_user(model) if model is not None else None
async def get_first_admin(self) -> User | None:
result = await self._session.execute(select(UserModel).where(UserModel.system_role == "admin").limit(1))
model = result.scalar_one_or_none()
return _to_user(model) if model is not None else None
async def update_user(self, data: User) -> User:
model = await self._session.get(UserModel, data.id)
if model is None:
raise UserNotFoundError(f"User {data.id} no longer exists")
model.email = data.email
model.password_hash = data.password_hash
model.system_role = data.system_role
model.oauth_provider = data.oauth_provider
model.oauth_id = data.oauth_id
model.needs_setup = data.needs_setup
model.token_version = data.token_version
await self._session.flush()
await self._session.refresh(model)
return _to_user(model)
async def count_users(self) -> int:
count = await self._session.scalar(select(func.count()).select_from(UserModel))
return int(count or 0)
async def count_admin_users(self) -> int:
count = await self._session.scalar(select(func.count()).select_from(UserModel).where(UserModel.system_role == "admin"))
return int(count or 0)
@@ -0,0 +1,36 @@
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories import (
FeedbackRepositoryProtocol,
RunEventRepositoryProtocol,
RunRepositoryProtocol,
ThreadMetaRepositoryProtocol,
UserRepositoryProtocol,
)
from store.repositories.db import (
DbFeedbackRepository,
DbRunEventRepository,
DbRunRepository,
DbThreadMetaRepository,
DbUserRepository,
)
def build_thread_meta_repository(session: AsyncSession) -> ThreadMetaRepositoryProtocol:
return DbThreadMetaRepository(session)
def build_run_repository(session: AsyncSession) -> RunRepositoryProtocol:
return DbRunRepository(session)
def build_feedback_repository(session: AsyncSession) -> FeedbackRepositoryProtocol:
return DbFeedbackRepository(session)
def build_run_event_repository(session: AsyncSession) -> RunEventRepositoryProtocol:
return DbRunEventRepository(session)
def build_user_repository(session: AsyncSession) -> UserRepositoryProtocol:
return DbUserRepository(session)
@@ -0,0 +1,7 @@
from store.repositories.models.feedback import Feedback
from store.repositories.models.run import Run
from store.repositories.models.run_event import RunEvent
from store.repositories.models.thread_meta import ThreadMeta
from store.repositories.models.user import User
__all__ = ["Feedback", "Run", "RunEvent", "ThreadMeta", "User"]
@@ -0,0 +1,36 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import Integer, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
class Feedback(DataClassBase):
"""Feedback table (create-only, no updated_time)."""
__tablename__ = "feedback"
__table_args__ = (
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
{"comment": "Feedback table."},
)
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
run_id: Mapped[str] = mapped_column(String(64), index=True)
thread_id: Mapped[str] = mapped_column(String(64), index=True)
rating: Mapped[int] = mapped_column(Integer)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
message_id: Mapped[str | None] = mapped_column(String(64), default=None)
comment: Mapped[str | None] = mapped_column(UniversalText, default=None)
created_time: Mapped[datetime] = mapped_column(
"created_at",
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -0,0 +1,63 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, Index, Integer, String
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
class Run(DataClassBase):
"""Run metadata table."""
__tablename__ = "runs"
__table_args__ = (
Index("ix_runs_thread_status", "thread_id", "status"),
{"comment": "Run metadata table."},
)
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
thread_id: Mapped[str] = mapped_column(String(64), index=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), default=None)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
status: Mapped[str] = mapped_column(String(20), default="pending", index=True)
model_name: Mapped[str | None] = mapped_column(String(128), default=None)
multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
error: Mapped[str | None] = mapped_column(UniversalText, default=None)
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64), default=None)
meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
kwargs: Mapped[dict[str, Any]] = mapped_column("kwargs_json", JSON, default_factory=dict)
total_input_tokens: Mapped[int] = mapped_column(Integer, default=0)
total_output_tokens: Mapped[int] = mapped_column(Integer, default=0)
total_tokens: Mapped[int] = mapped_column(Integer, default=0)
llm_call_count: Mapped[int] = mapped_column(Integer, default=0)
lead_agent_tokens: Mapped[int] = mapped_column(Integer, default=0)
subagent_tokens: Mapped[int] = mapped_column(Integer, default=0)
middleware_tokens: Mapped[int] = mapped_column(Integer, default=0)
message_count: Mapped[int] = mapped_column(Integer, default=0)
first_human_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
last_ai_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
created_time: Mapped[datetime] = mapped_column(
"created_at",
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
"updated_at",
TimeZone,
init=False,
default=None,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -0,0 +1,46 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import (
DataClassBase,
TimeZone,
UniversalText,
current_time,
id_key,
)
class RunEvent(DataClassBase):
"""Run event table."""
__tablename__ = "run_events"
__table_args__ = (
UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
Index("ix_events_run", "thread_id", "run_id", "seq"),
{"comment": "Run event table."},
)
id: Mapped[id_key] = mapped_column(init=False)
thread_id: Mapped[str] = mapped_column(String(64), index=True)
run_id: Mapped[str] = mapped_column(String(64), index=True)
event_type: Mapped[str] = mapped_column(String(32), index=True)
category: Mapped[str] = mapped_column(String(16), index=True)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
content: Mapped[str] = mapped_column(UniversalText, default="")
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
created_at: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Event timestamp",
)
@@ -0,0 +1,43 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, current_time
class ThreadMeta(DataClassBase):
"""Thread metadata table."""
__tablename__ = "threads_meta"
__table_args__ = {"comment": "Thread metadata table."}
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), default=None, index=True)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
display_name: Mapped[str | None] = mapped_column(String(256), default=None)
status: Mapped[str] = mapped_column(String(20), default="idle", index=True)
meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
created_time: Mapped[datetime] = mapped_column(
"created_at",
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
"updated_at",
TimeZone,
init=False,
default=None,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -0,0 +1,42 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import Boolean, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, current_time
class User(DataClassBase):
"""User account table."""
__tablename__ = "users"
__table_args__ = (
Index(
"idx_users_oauth_identity",
"oauth_provider",
"oauth_id",
unique=True,
sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
),
{"comment": "User account table."},
)
id: Mapped[str] = mapped_column(String(36), primary_key=True)
email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
system_role: Mapped[str] = mapped_column(String(16), default="user")
password_hash: Mapped[str | None] = mapped_column(String(128), default=None)
oauth_provider: Mapped[str | None] = mapped_column(String(32), default=None)
oauth_id: Mapped[str | None] = mapped_column(String(128), default=None)
needs_setup: Mapped[bool] = mapped_column(Boolean, default=False)
token_version: Mapped[int] = mapped_column(default=0)
created_at: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -0,0 +1,3 @@
from .timezone import get_timezone
__all__ = ["get_timezone"]
@@ -0,0 +1,51 @@
import zoneinfo
from datetime import UTC, datetime
from store.config.app_config import get_app_config
# IANA identifiers that map to UTC — see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones
_UTC_IDENTIFIERS = frozenset({"Etc/UCT", "Etc/Universal", "Etc/UTC", "Etc/Zulu", "UCT", "Universal", "UTC", "Zulu"})
class TimeZone:
def __init__(self) -> None:
app_config = get_app_config()
if app_config.timezone in _UTC_IDENTIFIERS:
self.tz_info = UTC
else:
self.tz_info = zoneinfo.ZoneInfo(app_config.timezone)
def now(self) -> datetime:
"""Return the current time in the configured timezone."""
return datetime.now(self.tz_info)
def from_datetime(self, t: datetime) -> datetime:
"""Convert a datetime to the configured timezone."""
return t.astimezone(self.tz_info)
def from_str(self, t_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> datetime:
"""Parse a time string and attach the configured timezone."""
return datetime.strptime(t_str, format_str).replace(tzinfo=self.tz_info)
@staticmethod
def to_str(t: datetime, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
"""Format a datetime to string."""
return t.strftime(format_str)
@staticmethod
def to_utc(t: datetime | int) -> datetime:
"""Convert a datetime or Unix timestamp to UTC."""
if isinstance(t, datetime):
return t.astimezone(UTC)
return datetime.fromtimestamp(t, tz=UTC)
_timezone = None
def get_timezone() -> TimeZone:
"""Return the global TimeZone singleton (lazy-initialized)."""
global _timezone
if _timezone is None:
_timezone = TimeZone()
return _timezone
+5 -2
View File
@@ -6,6 +6,7 @@ readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"deerflow-harness",
"deerflow-storage",
"fastapi>=0.115.0",
"httpx>=0.28.0",
"python-multipart>=0.0.27",
@@ -24,7 +25,8 @@ dependencies = [
]
[project.optional-dependencies]
postgres = ["deerflow-harness[postgres]"]
postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
mysql = ["deerflow-storage[mysql]"]
[dependency-groups]
dev = [
@@ -43,7 +45,8 @@ markers = [
index-url = "https://pypi.org/simple"
[tool.uv.workspace]
members = ["packages/harness"]
members = ["packages/harness", "packages/storage"]
[tool.uv.sources]
deerflow-harness = { workspace = true }
deerflow-storage = { workspace = true }
+93
View File
@@ -4,6 +4,8 @@ Sets up sys.path and pre-mocks modules that would cause circular import
issues when unit-testing lightweight config/registry code in isolation.
"""
from __future__ import annotations
import importlib.util
import sys
from pathlib import Path
@@ -11,11 +13,16 @@ from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from support.detectors.blocking_io import BlockingIOProbe, detect_blocking_io
# Make 'app' and 'deerflow' importable from any working directory
sys.path.insert(0, str(Path(__file__).parent.parent))
sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "scripts"))
_BACKEND_ROOT = Path(__file__).resolve().parents[1]
_blocking_io_probe = BlockingIOProbe(_BACKEND_ROOT)
_BLOCKING_IO_DETECTOR_ATTR = "_blocking_io_detector"
# Break the circular import chain that exists in production code:
# deerflow.subagents.__init__
# -> .executor (SubagentExecutor, SubagentResult)
@@ -56,6 +63,92 @@ def provisioner_module():
return module
@pytest.fixture()
def blocking_io_detector():
"""Fail a focused test if blocking calls run on the event loop thread."""
with detect_blocking_io(fail_on_exit=True) as detector:
yield detector
def pytest_addoption(parser: pytest.Parser) -> None:
group = parser.getgroup("blocking-io")
group.addoption(
"--detect-blocking-io",
action="store_true",
default=False,
help="Collect blocking calls made while an asyncio event loop is running and report a summary.",
)
group.addoption(
"--detect-blocking-io-fail",
action="store_true",
default=False,
help="Set a failing exit status when --detect-blocking-io records violations.",
)
def pytest_configure(config: pytest.Config) -> None:
config.addinivalue_line("markers", "no_blocking_io_probe: skip the optional blocking IO probe")
def pytest_sessionstart(session: pytest.Session) -> None:
if _blocking_io_probe_enabled(session.config):
_blocking_io_probe.clear()
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_call(item: pytest.Item):
if not _blocking_io_probe_enabled(item.config) or _blocking_io_probe_skipped(item):
yield
return
detector = detect_blocking_io(fail_on_exit=False, stack_limit=18)
detector.__enter__()
setattr(item, _BLOCKING_IO_DETECTOR_ATTR, detector)
yield
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_teardown(item: pytest.Item):
yield
detector = getattr(item, _BLOCKING_IO_DETECTOR_ATTR, None)
if detector is None:
return
try:
detector.__exit__(None, None, None)
_blocking_io_probe.record(item.nodeid, detector.violations)
finally:
delattr(item, _BLOCKING_IO_DETECTOR_ATTR)
def pytest_sessionfinish(session: pytest.Session) -> None:
if _blocking_io_fail_enabled(session.config) and _blocking_io_probe.violation_count and session.exitstatus == pytest.ExitCode.OK:
session.exitstatus = pytest.ExitCode.TESTS_FAILED
def pytest_terminal_summary(terminalreporter: pytest.TerminalReporter) -> None:
if not _blocking_io_probe_enabled(terminalreporter.config):
return
header, *details = _blocking_io_probe.format_summary().splitlines()
terminalreporter.write_sep("=", header)
for line in details:
terminalreporter.write_line(line)
def _blocking_io_probe_enabled(config: pytest.Config) -> bool:
return bool(config.getoption("--detect-blocking-io") or config.getoption("--detect-blocking-io-fail"))
def _blocking_io_fail_enabled(config: pytest.Config) -> bool:
return bool(config.getoption("--detect-blocking-io-fail"))
def _blocking_io_probe_skipped(item: pytest.Item) -> bool:
return item.path.name == "test_blocking_io_detector.py" or item.get_closest_marker("no_blocking_io_probe") is not None
# ---------------------------------------------------------------------------
# Auto-set user context for every test unless marked no_auto_user
# ---------------------------------------------------------------------------
+1
View File
@@ -0,0 +1 @@
"""Shared test support helpers."""
@@ -0,0 +1 @@
"""Runtime and static detectors used by tests."""
@@ -0,0 +1,287 @@
"""Test helper for detecting blocking calls on an asyncio event loop.
The detector is intentionally test-only. It monkeypatches a small set of
well-known blocking entry points and their already-loaded module-level aliases,
then records calls only when they happen on a thread that is currently running
an asyncio event loop. Aliases captured in closures or default arguments remain
out of scope.
"""
from __future__ import annotations
import asyncio
import importlib
import sys
import traceback
from collections import Counter
from collections.abc import Callable, Iterable, Iterator
from contextlib import AbstractContextManager
from dataclasses import dataclass
from functools import wraps
from pathlib import Path
from types import TracebackType
from typing import Any
BlockingCallable = Callable[..., Any]
@dataclass(frozen=True)
class BlockingCallSpec:
"""Describes one blocking callable to wrap during a detector run."""
name: str
target: str
record_on_iteration: bool = False
@dataclass(frozen=True)
class BlockingCall:
"""One blocking call observed on an asyncio event loop thread."""
name: str
target: str
stack: tuple[traceback.FrameSummary, ...]
DEFAULT_BLOCKING_CALL_SPECS: tuple[BlockingCallSpec, ...] = (
BlockingCallSpec("time.sleep", "time:sleep"),
BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),
BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),
BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),
BlockingCallSpec("pathlib.Path.resolve", "pathlib:Path.resolve"),
BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),
BlockingCallSpec("pathlib.Path.write_text", "pathlib:Path.write_text"),
)
def _is_event_loop_thread() -> bool:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return False
return loop.is_running()
def _resolve_target(target: str) -> tuple[object, str, BlockingCallable]:
module_name, attr_path = target.split(":", maxsplit=1)
owner: object = importlib.import_module(module_name)
parts = attr_path.split(".")
for part in parts[:-1]:
owner = getattr(owner, part)
attr_name = parts[-1]
original = getattr(owner, attr_name)
return owner, attr_name, original
def _trim_detector_frames(stack: Iterable[traceback.FrameSummary]) -> tuple[traceback.FrameSummary, ...]:
return tuple(frame for frame in stack if frame.filename != __file__)
class BlockingIODetector(AbstractContextManager["BlockingIODetector"]):
"""Record blocking calls made from async runtime code.
By default the detector reports violations but does not fail on context
exit. Tests can set ``fail_on_exit=True`` or call
``assert_no_blocking_calls()`` explicitly.
"""
def __init__(
self,
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
*,
fail_on_exit: bool = False,
patch_loaded_aliases: bool = True,
stack_limit: int = 12,
) -> None:
self._specs = tuple(specs)
self._fail_on_exit = fail_on_exit
self._patch_loaded_aliases_enabled = patch_loaded_aliases
self._stack_limit = stack_limit
self._patches: list[tuple[object, str, BlockingCallable]] = []
self._patch_keys: set[tuple[int, str]] = set()
self.violations: list[BlockingCall] = []
self._active = False
def __enter__(self) -> BlockingIODetector:
try:
self._active = True
alias_replacements: dict[int, BlockingCallable] = {}
for spec in self._specs:
owner, attr_name, original = _resolve_target(spec.target)
wrapper = self._wrap(spec, original)
self._patch_attribute(owner, attr_name, original, wrapper)
alias_replacements[id(original)] = wrapper
if self._patch_loaded_aliases_enabled:
self._patch_loaded_module_aliases(alias_replacements)
except Exception:
self._restore()
self._active = False
raise
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback_value: TracebackType | None,
) -> bool | None:
self._restore()
self._active = False
if exc_type is None and self._fail_on_exit:
self.assert_no_blocking_calls()
return None
def _restore(self) -> None:
for owner, attr_name, original in reversed(self._patches):
setattr(owner, attr_name, original)
self._patches.clear()
self._patch_keys.clear()
def _patch_attribute(self, owner: object, attr_name: str, original: BlockingCallable, replacement: BlockingCallable) -> None:
key = (id(owner), attr_name)
if key in self._patch_keys:
return
setattr(owner, attr_name, replacement)
self._patches.append((owner, attr_name, original))
self._patch_keys.add(key)
def _patch_loaded_module_aliases(self, replacements_by_id: dict[int, BlockingCallable]) -> None:
for module in tuple(sys.modules.values()):
namespace = getattr(module, "__dict__", None)
if not isinstance(namespace, dict):
continue
for attr_name, value in tuple(namespace.items()):
replacement = replacements_by_id.get(id(value))
if replacement is not None:
self._patch_attribute(module, attr_name, value, replacement)
def _wrap(self, spec: BlockingCallSpec, original: BlockingCallable) -> BlockingCallable:
@wraps(original)
def wrapper(*args: Any, **kwargs: Any) -> Any:
if spec.record_on_iteration:
result = original(*args, **kwargs)
return self._wrap_iteration(spec, result)
self._record_if_blocking(spec)
return original(*args, **kwargs)
return wrapper
def _wrap_iteration(self, spec: BlockingCallSpec, iterable: Iterable[Any]) -> Iterator[Any]:
iterator = iter(iterable)
reported = False
while True:
if not reported:
reported = self._record_if_blocking(spec)
try:
yield next(iterator)
except StopIteration:
return
def _record_if_blocking(self, spec: BlockingCallSpec) -> bool:
if self._active and _is_event_loop_thread():
stack = _trim_detector_frames(traceback.extract_stack(limit=self._stack_limit))
self.violations.append(BlockingCall(spec.name, spec.target, stack))
return True
return False
def assert_no_blocking_calls(self) -> None:
if self.violations:
raise AssertionError(format_blocking_calls(self.violations))
class BlockingIOProbe:
"""Collect detector output across tests and format a compact summary."""
def __init__(self, project_root: Path) -> None:
self._project_root = project_root.resolve()
self._observed: list[tuple[str, BlockingCall]] = []
@property
def violation_count(self) -> int:
return len(self._observed)
@property
def test_count(self) -> int:
return len({nodeid for nodeid, _violation in self._observed})
def clear(self) -> None:
self._observed.clear()
def record(self, nodeid: str, violations: Iterable[BlockingCall]) -> None:
for violation in violations:
self._observed.append((nodeid, violation))
def format_summary(self, *, limit: int = 30) -> str:
if not self._observed:
return "blocking io probe: no violations"
call_sites: Counter[tuple[str, str, int, str, str]] = Counter()
for _nodeid, violation in self._observed:
frame = self._local_call_site(violation.stack)
if frame is None:
call_sites[(violation.name, "<unknown>", 0, "<unknown>", "")] += 1
continue
call_sites[
(
violation.name,
self._relative(frame.filename),
frame.lineno,
frame.name,
(frame.line or "").strip(),
)
] += 1
lines = [f"blocking io probe: {self.violation_count} violations across {self.test_count} tests", "Top call sites:"]
for (name, filename, lineno, function, line), count in call_sites.most_common(limit):
lines.append(f"{count:4d} {name} {filename}:{lineno} {function} | {line}")
return "\n".join(lines)
def _relative(self, filename: str) -> str:
try:
return str(Path(filename).resolve().relative_to(self._project_root))
except ValueError:
return filename
def _local_call_site(self, stack: tuple[traceback.FrameSummary, ...]) -> traceback.FrameSummary | None:
local_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename and not self._relative(frame.filename).startswith("tests/")]
if local_frames:
return local_frames[-1]
test_frames = [frame for frame in stack if str(self._project_root) in frame.filename and "/.venv/" not in frame.filename]
return test_frames[-1] if test_frames else None
def detect_blocking_io(
specs: Iterable[BlockingCallSpec] = DEFAULT_BLOCKING_CALL_SPECS,
*,
fail_on_exit: bool = False,
patch_loaded_aliases: bool = True,
stack_limit: int = 12,
) -> BlockingIODetector:
"""Create a detector context manager for a focused test scope."""
return BlockingIODetector(specs, fail_on_exit=fail_on_exit, patch_loaded_aliases=patch_loaded_aliases, stack_limit=stack_limit)
def format_blocking_calls(violations: Iterable[BlockingCall]) -> str:
"""Format detector output with enough stack context to locate call sites."""
lines = ["Blocking calls were executed on an asyncio event loop thread:"]
for index, violation in enumerate(violations, start=1):
lines.append(f"{index}. {violation.name} ({violation.target})")
lines.extend(_format_stack(violation.stack))
return "\n".join(lines)
def _format_stack(stack: Iterable[traceback.FrameSummary]) -> Iterator[str]:
for frame in stack:
location = f"{frame.filename}:{frame.lineno}"
lines = [f" at {frame.name} ({location})"]
if frame.line:
lines.append(f" {frame.line.strip()}")
yield from lines
+190
View File
@@ -0,0 +1,190 @@
from __future__ import annotations
import asyncio
import os
import time
from os import walk as imported_walk
from pathlib import Path
from time import sleep as imported_sleep
import httpx
import pytest
import requests
from support.detectors.blocking_io import (
BlockingCallSpec,
BlockingIOProbe,
detect_blocking_io,
)
pytestmark = pytest.mark.asyncio
TIME_SLEEP_ONLY = (BlockingCallSpec("time.sleep", "time:sleep"),)
REQUESTS_ONLY = (BlockingCallSpec("requests.Session.request", "requests.sessions:Session.request"),)
HTTPX_ONLY = (BlockingCallSpec("httpx.Client.request", "httpx:Client.request"),)
OS_WALK_ONLY = (BlockingCallSpec("os.walk", "os:walk", record_on_iteration=True),)
PATH_READ_TEXT_ONLY = (BlockingCallSpec("pathlib.Path.read_text", "pathlib:Path.read_text"),)
async def test_records_time_sleep_on_event_loop() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
time.sleep(0)
assert [violation.name for violation in detector.violations] == ["time.sleep"]
async def test_records_already_imported_sleep_alias_on_event_loop() -> None:
original_alias = imported_sleep
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
imported_sleep(0)
assert imported_sleep is original_alias
assert [violation.name for violation in detector.violations] == ["time.sleep"]
async def test_can_disable_loaded_alias_patching() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY, patch_loaded_aliases=False) as detector:
imported_sleep(0)
assert detector.violations == []
async def test_does_not_record_time_sleep_offloaded_to_thread() -> None:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
await asyncio.to_thread(time.sleep, 0)
assert detector.violations == []
async def test_fixture_allows_offloaded_sync_work(blocking_io_detector) -> None:
await asyncio.to_thread(time.sleep, 0)
assert blocking_io_detector.violations == []
async def test_does_not_record_sync_call_without_running_event_loop() -> None:
def call_sleep() -> list[str]:
with detect_blocking_io(TIME_SLEEP_ONLY) as detector:
time.sleep(0)
return [violation.name for violation in detector.violations]
assert await asyncio.to_thread(call_sleep) == []
async def test_fail_on_exit_includes_call_site() -> None:
with pytest.raises(AssertionError) as exc_info:
with detect_blocking_io(TIME_SLEEP_ONLY, fail_on_exit=True):
time.sleep(0)
message = str(exc_info.value)
assert "time.sleep" in message
assert "test_fail_on_exit_includes_call_site" in message
async def test_records_requests_session_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_request(self: requests.Session, method: str, url: str, **kwargs: object) -> str:
return f"{method}:{url}"
monkeypatch.setattr(requests.sessions.Session, "request", fake_request)
with detect_blocking_io(REQUESTS_ONLY) as detector:
assert requests.get("https://example.invalid") == "get:https://example.invalid"
assert [violation.name for violation in detector.violations] == ["requests.Session.request"]
async def test_records_sync_httpx_client_request_without_real_network(monkeypatch: pytest.MonkeyPatch) -> None:
def fake_request(self: httpx.Client, method: str, url: str, **kwargs: object) -> httpx.Response:
return httpx.Response(200, request=httpx.Request(method, url))
monkeypatch.setattr(httpx.Client, "request", fake_request)
with detect_blocking_io(HTTPX_ONLY) as detector:
with httpx.Client() as client:
response = client.get("https://example.invalid")
assert response.status_code == 200
assert [violation.name for violation in detector.violations] == ["httpx.Client.request"]
async def test_records_os_walk_on_event_loop(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
with detect_blocking_io(OS_WALK_ONLY) as detector:
assert list(os.walk(tmp_path))
assert [violation.name for violation in detector.violations] == ["os.walk"]
async def test_records_already_imported_os_walk_alias_on_iteration(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
original_alias = imported_walk
with detect_blocking_io(OS_WALK_ONLY) as detector:
assert list(imported_walk(tmp_path))
assert imported_walk is original_alias
assert [violation.name for violation in detector.violations] == ["os.walk"]
async def test_does_not_record_os_walk_before_iteration(tmp_path: Path) -> None:
with detect_blocking_io(OS_WALK_ONLY) as detector:
walker = os.walk(tmp_path)
assert list(walker)
assert detector.violations == []
async def test_does_not_record_os_walk_iterated_off_event_loop(tmp_path: Path) -> None:
(tmp_path / "nested").mkdir()
with detect_blocking_io(OS_WALK_ONLY) as detector:
walker = os.walk(tmp_path)
assert await asyncio.to_thread(lambda: list(walker))
assert detector.violations == []
async def test_records_path_read_text_on_event_loop(tmp_path: Path) -> None:
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY) as detector:
assert path.read_text(encoding="utf-8") == "content"
assert [violation.name for violation in detector.violations] == ["pathlib.Path.read_text"]
async def test_probe_formats_summary_for_recorded_violations(tmp_path: Path) -> None:
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
assert path.read_text(encoding="utf-8") == "content"
probe.record("tests/test_example.py::test_example", detector.violations)
summary = probe.format_summary()
assert "blocking io probe: 1 violations across 1 tests" in summary
assert "pathlib.Path.read_text" in summary
async def test_probe_formats_empty_summary_and_can_be_cleared(tmp_path: Path) -> None:
probe = BlockingIOProbe(Path(__file__).resolve().parents[1])
assert probe.format_summary() == "blocking io probe: no violations"
path = tmp_path / "data.txt"
path.write_text("content", encoding="utf-8")
with detect_blocking_io(PATH_READ_TEXT_ONLY, stack_limit=18) as detector:
assert path.read_text(encoding="utf-8") == "content"
probe.record("tests/test_example.py::test_example", detector.violations)
assert probe.violation_count == 1
probe.clear()
assert probe.violation_count == 0
assert probe.format_summary() == "blocking io probe: no violations"
@@ -0,0 +1,22 @@
from __future__ import annotations
import time
import pytest
ORIGINAL_SLEEP = time.sleep
def replacement_sleep(seconds: float) -> None:
return None
def test_probe_survives_monkeypatch_teardown(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setattr(time, "sleep", replacement_sleep)
assert time.sleep is replacement_sleep
@pytest.mark.no_blocking_io_probe
def test_probe_restores_original_after_monkeypatch_teardown() -> None:
assert time.sleep is ORIGINAL_SLEEP
assert getattr(time.sleep, "__wrapped__", None) is None
+5 -2
View File
@@ -94,12 +94,15 @@ class TestHarnessPackaging:
"psycopg-pool>=3.3.0",
]
def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
def test_workspace_pyproject_forwards_postgres_extra_to_storage_packages(self):
pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
data = tomllib.loads(pyproject_path.read_text())
optional_dependencies = data["project"]["optional-dependencies"]
assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
assert optional_dependencies["postgres"] == [
"deerflow-harness[postgres]",
"deerflow-storage[postgres]",
]
def test_postgres_missing_dependency_messages_recommend_package_extra(self):
assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
@@ -0,0 +1,222 @@
"""Real-LLM end-to-end verification for issue #2884.
Drives a real ``langchain.agents.create_agent`` graph against a real OpenAI-
compatible LLM (one-api gateway), bound through ``DeferredToolFilterMiddleware``
and the production ``get_available_tools`` pipeline. The only thing we mock is
the MCP tool source we hand-roll two ``@tool``s and inject them through
``deerflow.mcp.cache.get_cached_mcp_tools``.
The flow exercised:
1. Turn 1: agent sees ``tool_search`` (plus a ``fake_subagent_trigger``
that re-enters ``get_available_tools`` on the same task this is the
code path issue #2884 reports). It must call ``tool_search`` to
discover the deferred ``fake_calculator`` tool.
2. Tool batch: ``tool_search`` promotes ``fake_calculator``;
``fake_subagent_trigger`` re-enters ``get_available_tools``.
3. Turn 2: the promoted ``fake_calculator`` schema must reach the model
so it can actually call it. Without this PR's fix, the re-entry wipes
the promotion and the model can no longer invoke the tool.
Skipped unless ``ONEAPI_E2E=1`` is set so this doesn't burn credits on every
test run. Run with::
ONEAPI_E2E=1 OPENAI_API_KEY=... OPENAI_API_BASE=... \
PYTHONPATH=. uv run pytest \
tests/test_deferred_tool_promotion_real_llm.py -v -s
"""
from __future__ import annotations
import os
import pytest
from langchain_core.messages import HumanMessage
from langchain_core.tools import tool as as_tool
# ---------------------------------------------------------------------------
# Skip control: only run when explicitly opted in.
# ---------------------------------------------------------------------------
pytestmark = pytest.mark.skipif(
os.getenv("ONEAPI_E2E") != "1",
reason="Real-LLM e2e: opt in with ONEAPI_E2E=1 (requires OPENAI_API_KEY + OPENAI_API_BASE)",
)
# ---------------------------------------------------------------------------
# Fake "MCP" tools the agent should discover via tool_search.
# Keep them obviously synthetic so the model can pattern-match the search.
# ---------------------------------------------------------------------------
_calls: list[str] = []
@as_tool
def fake_calculator(expression: str) -> str:
"""Evaluate a tiny arithmetic expression like '2 + 2'.
Reserved for the user only call this if the user asks for arithmetic.
"""
_calls.append(f"fake_calculator:{expression}")
try:
# Trivially safe-eval just for the e2e check
allowed = set("0123456789+-*/() .")
if not set(expression) <= allowed:
return "expression contains disallowed characters"
return str(eval(expression, {"__builtins__": {}}, {})) # noqa: S307
except Exception as e:
return f"error: {e}"
@as_tool
def fake_translator(text: str, target_lang: str) -> str:
"""Translate text into the given language code. Decorative — not used."""
_calls.append(f"fake_translator:{text}:{target_lang}")
return f"[{target_lang}] {text}"
# ---------------------------------------------------------------------------
# Pipeline wiring (same shape as the in-process tests).
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _reset_registry_between_tests():
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Build a minimal mock AppConfig and patch the symbol — never call the
real loader, which would trigger ``_apply_singleton_configs`` and
permanently mutate cross-test singletons (memory, title, )."""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Real-LLM e2e test
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_real_llm_promotes_then_invokes_with_subagent_reentry(monkeypatch: pytest.MonkeyPatch):
"""End-to-end against a real OpenAI-compatible LLM.
The model must:
Turn 1 see ``tool_search`` (deferred tools aren't bound yet) and
batch-call BOTH ``tool_search(select:fake_calculator)`` AND
``fake_subagent_trigger(...)``.
Turn 2 call ``fake_calculator`` and finish.
Pass criterion: ``fake_calculator`` actually gets invoked at the tool
layer recorded in ``_calls`` which proves the model received the
promoted schema after the re-entrant ``get_available_tools`` call.
"""
from langchain.agents import create_agent
from langchain_openai import ChatOpenAI
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_calculator, fake_translator])
_force_tool_search_enabled(monkeypatch)
_calls.clear()
@as_tool
async def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset.
Use this whenever the user asks you to delegate work pass a short
description as ``prompt``.
"""
# ``task_tool`` does this internally. Whether the registry-reset that
# used to happen here actually leaks back to the parent task depends
# on asyncio's implicit context-copying semantics (gather creates
# child tasks with copied contexts, so reset_deferred_registry is
# task-local) — but the fix in this PR is what GUARANTEES the
# promotion sticks regardless of which integration path triggers a
# re-entrant ``get_available_tools`` call.
get_available_tools(subagent_enabled=False)
_calls.append(f"fake_subagent_trigger:{prompt}")
return "subagent completed"
tools = get_available_tools() + [fake_subagent_trigger]
model = ChatOpenAI(
model=os.environ.get("ONEAPI_MODEL", "claude-sonnet-4-6"),
api_key=os.environ["OPENAI_API_KEY"],
base_url=os.environ["OPENAI_API_BASE"],
temperature=0,
max_retries=1,
)
system_prompt = (
"You are a meticulous assistant. Available deferred tools include a "
"calculator and a translator — their schemas are hidden until you "
"search for them via tool_search.\n\n"
"Procedure for the user's request:\n"
" 1. Call tool_search with query 'select:fake_calculator' AND "
"in the SAME tool batch also call fake_subagent_trigger(prompt='go') "
"to delegate the side work. Put both tool_calls in your first response.\n"
" 2. After both tool messages come back, call fake_calculator with "
"the user's expression.\n"
" 3. Reply with just the numeric result."
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt=system_prompt,
)
result = await graph.ainvoke(
{"messages": [HumanMessage(content="What is 17 * 23? Use the deferred calculator tool.")]},
config={"recursion_limit": 12},
)
print("\n=== tool calls recorded ===")
for c in _calls:
print(f" {c}")
print("\n=== final message ===")
final_text = result["messages"][-1].content if result["messages"] else "(none)"
print(f" {final_text!r}")
# The smoking-gun assertion: fake_calculator was actually invoked at the
# tool layer. This is only possible if the promoted schema reached the
# model in turn 2, despite the subagent-style re-entry in turn 1.
calc_calls = [c for c in _calls if c.startswith("fake_calculator:")]
assert calc_calls, f"REGRESSION (#2884): the model never managed to call fake_calculator. All recorded tool calls: {_calls!r}. Final text: {final_text!r}"
# And the math should actually be done correctly (sanity that the LLM
# really used the result, not just hallucinated the answer).
assert "391" in str(final_text), f"Model didn't surface 17*23=391. Final text: {final_text!r}"
@@ -0,0 +1,390 @@
"""Reproduce + regression-guard issue #2884.
Hypothesis from the issue:
``tools.tools.get_available_tools`` unconditionally calls
``reset_deferred_registry()`` and constructs a fresh ``DeferredToolRegistry``
every time it is invoked. If anything calls ``get_available_tools`` again
during the same async context (after the agent has promoted tools via
``tool_search``), the promotion is wiped and the next model call hides the
tool's schema again.
These tests pin two things:
A. **At the unit boundary** verify the failure mode directly. Promote a
tool in the registry, then call ``get_available_tools`` again and observe
that the ContextVar registry is reset and the promotion is lost.
B. **At the graph-execution boundary** drive a real ``create_agent`` graph
with the real ``DeferredToolFilterMiddleware`` through two model turns.
The first turn calls ``tool_search`` which promotes a tool. The second
turn must see that tool's schema in ``request.tools``. If
``get_available_tools`` were to run again between the two turns and reset
the registry, the second turn's filter would strip the tool.
Strategy: use the production ``deerflow.tools.tools.get_available_tools``
unmodified; mock only the LLM and the MCP tool source. Patch
``deerflow.mcp.cache.get_cached_mcp_tools`` (the symbol that
``get_available_tools`` resolves via lazy import) to return our fixture
tools so we don't need a real MCP server.
"""
from __future__ import annotations
from typing import Any
import pytest
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.runnables import Runnable
from langchain_core.tools import tool as as_tool
class FakeToolCallingModel(FakeMessagesListChatModel):
"""FakeMessagesListChatModel + no-op bind_tools so create_agent works."""
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
return self
# ---------------------------------------------------------------------------
# Fixtures: a fake MCP tool source + a way to force config.tool_search.enabled
# ---------------------------------------------------------------------------
@as_tool
def fake_mcp_search(query: str) -> str:
"""Pretend to search a knowledge base for the given query."""
return f"results for {query}"
@as_tool
def fake_mcp_fetch(url: str) -> str:
"""Pretend to fetch a page at the given URL."""
return f"content of {url}"
@pytest.fixture(autouse=True)
def _supply_env(monkeypatch: pytest.MonkeyPatch):
"""config.yaml references $OPENAI_API_KEY at parse time; supply a placeholder."""
monkeypatch.setenv("OPENAI_API_KEY", "sk-fake-not-used")
monkeypatch.setenv("OPENAI_API_BASE", "https://example.invalid")
@pytest.fixture(autouse=True)
def _reset_deferred_registry_between_tests():
"""Each test must start with a clean ContextVar.
The registry lives in a module-level ContextVar with no per-task isolation
in a synchronous test runner, so one test's promotion can leak into the
next and silently break filter assertions.
"""
from deerflow.tools.builtins.tool_search import reset_deferred_registry
reset_deferred_registry()
yield
reset_deferred_registry()
def _patch_mcp_pipeline(monkeypatch: pytest.MonkeyPatch, mcp_tools: list) -> None:
"""Make get_available_tools believe an MCP server is registered.
Build a real ``ExtensionsConfig`` with one enabled MCP server entry so
that both ``AppConfig.from_file`` (which calls
``ExtensionsConfig.from_file().model_dump()``) and ``tools.get_available_tools``
(which calls ``ExtensionsConfig.from_file().get_enabled_mcp_servers()``)
see a valid instance. Then point the MCP tool cache at our fixture tools.
"""
from deerflow.config.extensions_config import ExtensionsConfig, McpServerConfig
real_ext = ExtensionsConfig(
mcpServers={"fake-server": McpServerConfig(type="stdio", command="echo", enabled=True)},
)
monkeypatch.setattr(
"deerflow.config.extensions_config.ExtensionsConfig.from_file",
classmethod(lambda cls: real_ext),
)
monkeypatch.setattr("deerflow.mcp.cache.get_cached_mcp_tools", lambda: list(mcp_tools))
def _force_tool_search_enabled(monkeypatch: pytest.MonkeyPatch) -> None:
"""Force config.tool_search.enabled=True without touching the yaml.
Calling the real ``get_app_config()`` would trigger ``_apply_singleton_configs``
which permanently mutates module-level singletons (``_memory_config``,
``_title_config``, ) to match the developer's ``config.yaml`` — even
after pytest restores our patch. That leaks across tests later in the
run that rely on those singletons' DEFAULTS (e.g. memory queue tests
require ``_memory_config.enabled = True``, which is the dataclass default
but FALSE in the actual yaml).
Build a minimal mock AppConfig instead and never call the real loader.
"""
from deerflow.config.app_config import AppConfig
from deerflow.config.tool_search_config import ToolSearchConfig
mock_cfg = AppConfig.model_construct(
log_level="info",
models=[],
tools=[],
tool_groups=[],
sandbox=AppConfig.model_fields["sandbox"].annotation.model_construct(use="x"),
tool_search=ToolSearchConfig(enabled=True),
)
monkeypatch.setattr("deerflow.tools.tools.get_app_config", lambda: mock_cfg)
# ---------------------------------------------------------------------------
# Section A — direct unit-level reproduction
# ---------------------------------------------------------------------------
def test_get_available_tools_preserves_promotions_across_reentrant_calls(monkeypatch: pytest.MonkeyPatch):
"""Re-entrant ``get_available_tools()`` must preserve prior promotions.
Step 1: call get_available_tools() registers MCP tools as deferred.
Step 2: simulate the agent calling tool_search by promoting one tool.
Step 3: call get_available_tools() again (the same code path
``task_tool`` exercises mid-run).
Assertion: after step 3, the promoted tool is STILL promoted (not
re-deferred). On ``main`` before the fix, step 3's
``reset_deferred_registry()`` wiped the promotion and re-registered
every MCP tool as deferred this assertion fired with
``REGRESSION (#2884)``.
"""
from deerflow.tools.builtins.tool_search import get_deferred_registry
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# Step 1: first call — both MCP tools start deferred
get_available_tools()
reg1 = get_deferred_registry()
assert reg1 is not None
assert {e.name for e in reg1.entries} == {"fake_mcp_search", "fake_mcp_fetch"}
# Step 2: simulate tool_search promoting one of them
reg1.promote({"fake_mcp_search"})
assert {e.name for e in reg1.entries} == {"fake_mcp_fetch"}, "Sanity: promote should remove fake_mcp_search"
# Step 3: second call — registry must NOT silently undo the promotion
get_available_tools()
reg2 = get_deferred_registry()
assert reg2 is not None
deferred_after = {e.name for e in reg2.entries}
assert "fake_mcp_search" not in deferred_after, f"REGRESSION (#2884): get_available_tools wiped the deferred registry, re-deferring a tool that was already promoted by tool_search. deferred_after_second_call={deferred_after!r}"
# ---------------------------------------------------------------------------
# Section B — graph-execution reproduction
# ---------------------------------------------------------------------------
class _ToolSearchPromotingModel(FakeToolCallingModel):
"""Two-turn model that:
Turn 1 emit a tool_call for ``tool_search`` (the real one)
Turn 2 emit a tool_call for ``fake_mcp_search`` (the promoted tool)
Records the tools it received on each turn so the test can inspect what
DeferredToolFilterMiddleware actually fed to ``bind_tools``.
"""
bound_tools_per_turn: list[list[str]] = []
def bind_tools( # type: ignore[override]
self,
tools: Any,
*,
tool_choice: Any = None,
**kwargs: Any,
) -> Runnable:
# Record the tool names the model would see in this turn
names = [getattr(t, "name", getattr(t, "__name__", repr(t))) for t in tools]
self.bound_tools_per_turn.append(names)
return self
def _build_promoting_model() -> _ToolSearchPromotingModel:
return _ToolSearchPromotingModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
}
],
),
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
def test_promoted_tool_is_visible_to_model_on_second_turn(monkeypatch: pytest.MonkeyPatch):
"""End-to-end: drive a real create_agent graph through two turns.
Without the fix, the second-turn bind_tools call should NOT contain
fake_mcp_search (because DeferredToolFilterMiddleware sees it in the
registry and strips it). With the fix, the model sees the schema and can
invoke it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
tools = get_available_tools()
# Sanity: the assembled tool list includes the deferred tools (they're in
# bind_tools but DeferredToolFilterMiddleware strips deferred ones before
# they reach the model)
tool_names = {getattr(t, "name", "") for t in tools}
assert {"tool_search", "fake_mcp_search", "fake_mcp_fetch"} <= tool_names
model = _build_promoting_model()
model.bound_tools_per_turn = [] # reset class-level recorder
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1: model should NOT see fake_mcp_search (it's deferred)
turn1 = set(model.bound_tools_per_turn[0])
assert "fake_mcp_search" not in turn1, f"Turn 1 sanity: deferred tools must be hidden from the model. Saw: {turn1!r}"
assert "tool_search" in turn1, f"Turn 1 sanity: tool_search must be visible so the agent can discover. Saw: {turn1!r}"
# Turn 2: AFTER tool_search promotes fake_mcp_search, the model must see it.
# This is the load-bearing assertion for issue #2884.
assert len(model.bound_tools_per_turn) >= 2, f"Expected at least 2 model turns, got {len(model.bound_tools_per_turn)}"
turn2 = set(model.bound_tools_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): tool_search promoted fake_mcp_search in turn 1, but the deferred-tool filter still hid it from the model in turn 2. Turn 2 bound tools: {turn2!r}"
# ---------------------------------------------------------------------------
# Section C — the actual issue #2884 trigger: a re-entrant
# get_available_tools call (e.g. when task_tool spawns a subagent) must not
# wipe the parent's promotion.
# ---------------------------------------------------------------------------
def test_reentrant_get_available_tools_preserves_promotion(monkeypatch: pytest.MonkeyPatch):
"""Issue #2884 in its real shape: a re-entrant get_available_tools call
(the same pattern that happens when ``task_tool`` builds a subagent's
toolset mid-run) must not wipe the parent agent's tool_search promotions.
Turn 1's tool batch contains BOTH ``tool_search`` (which promotes
``fake_mcp_search``) AND ``fake_subagent_trigger`` (which calls
``get_available_tools`` again exactly what ``task_tool`` does when it
builds a subagent's toolset). With the fix, turn 2's bind_tools sees the
promoted tool. Without the fix, the re-entry wipes the registry and
the filter re-hides it.
"""
from langchain.agents import create_agent
from deerflow.agents.middlewares.deferred_tool_filter_middleware import DeferredToolFilterMiddleware
from deerflow.tools.tools import get_available_tools
_patch_mcp_pipeline(monkeypatch, [fake_mcp_search, fake_mcp_fetch])
_force_tool_search_enabled(monkeypatch)
# The trigger tool simulates what task_tool does internally: rebuild the
# toolset by calling get_available_tools while the registry is live.
@as_tool
def fake_subagent_trigger(prompt: str) -> str:
"""Pretend to spawn a subagent. Internally rebuilds the toolset."""
get_available_tools(subagent_enabled=False)
return f"spawned subagent for: {prompt}"
tools = get_available_tools() + [fake_subagent_trigger]
bound_per_turn: list[list[str]] = []
class _Model(FakeToolCallingModel):
def bind_tools(self, tools_arg, **kwargs): # type: ignore[override]
bound_per_turn.append([getattr(t, "name", repr(t)) for t in tools_arg])
return self
model = _Model(
responses=[
# Turn 1: do both in one batch — promote AND trigger the
# subagent-style rebuild. LangGraph executes them in order in the
# same agent step.
AIMessage(
content="",
tool_calls=[
{
"name": "tool_search",
"args": {"query": "select:fake_mcp_search"},
"id": "call_search_1",
"type": "tool_call",
},
{
"name": "fake_subagent_trigger",
"args": {"prompt": "go"},
"id": "call_trigger_1",
"type": "tool_call",
},
],
),
# Turn 2: try to invoke the promoted tool. The model gets this
# turn only if turn 1's bind_tools recorded what the filter sent.
AIMessage(
content="",
tool_calls=[
{
"name": "fake_mcp_search",
"args": {"query": "hello"},
"id": "call_mcp_1",
"type": "tool_call",
}
],
),
AIMessage(content="all done"),
]
)
graph = create_agent(
model=model,
tools=tools,
middleware=[DeferredToolFilterMiddleware()],
system_prompt="bug-2884-subagent-repro",
)
graph.invoke({"messages": [HumanMessage(content="use the search tool")]})
# Turn 1 sanity: deferred tool not visible yet
assert "fake_mcp_search" not in set(bound_per_turn[0]), bound_per_turn[0]
# The smoking-gun assertion: turn 2 sees the promoted tool DESPITE the
# re-entrant get_available_tools call that happened in turn 1's tool batch.
assert len(bound_per_turn) >= 2, f"Expected ≥2 turns, got {len(bound_per_turn)}"
turn2 = set(bound_per_turn[1])
assert "fake_mcp_search" in turn2, f"REGRESSION (#2884): a re-entrant get_available_tools call (e.g. task_tool spawning a subagent) wiped the parent agent's promotion. Turn 2 bound tools: {turn2!r}"
+83 -1
View File
@@ -1,6 +1,6 @@
import threading
import time
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock, call, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
@@ -164,3 +164,85 @@ def test_flush_nowait_is_non_blocking() -> None:
assert elapsed < 0.1
assert finished.is_set() is False
assert finished.wait(1.0) is True
def test_queue_keeps_updates_for_different_agents_in_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
assert queue.pending_count == 2
assert [context.agent_name for context in queue._queue] == ["agent-a", "agent-b"]
def test_queue_still_coalesces_updates_for_same_agent_in_same_thread() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(
thread_id="thread-1",
messages=["first"],
agent_name="agent-a",
correction_detected=True,
)
queue.add(
thread_id="thread-1",
messages=["second"],
agent_name="agent-a",
correction_detected=False,
)
assert queue.pending_count == 1
assert queue._queue[0].agent_name == "agent-a"
assert queue._queue[0].messages == ["second"]
assert queue._queue[0].correction_detected is True
def test_process_queue_updates_different_agents_in_same_thread_separately() -> None:
queue = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=_memory_config(enabled=True)),
patch.object(queue, "_reset_timer"),
):
queue.add(thread_id="thread-1", messages=["agent-a"], agent_name="agent-a")
queue.add(thread_id="thread-1", messages=["agent-b"], agent_name="agent-b")
mock_updater = MagicMock()
mock_updater.update_memory.return_value = True
with (
patch("deerflow.agents.memory.updater.MemoryUpdater", return_value=mock_updater),
patch("deerflow.agents.memory.queue.time.sleep"),
):
queue.flush()
assert mock_updater.update_memory.call_count == 2
mock_updater.update_memory.assert_has_calls(
[
call(
messages=["agent-a"],
thread_id="thread-1",
agent_name="agent-a",
correction_detected=False,
reinforcement_detected=False,
user_id=None,
),
call(
messages=["agent-b"],
thread_id="thread-1",
agent_name="agent-b",
correction_detected=False,
reinforcement_detected=False,
user_id=None,
),
]
)
@@ -3,6 +3,7 @@
from unittest.mock import MagicMock, patch
from deerflow.agents.memory.queue import ConversationContext, MemoryUpdateQueue
from deerflow.config.memory_config import MemoryConfig
def test_conversation_context_has_user_id():
@@ -17,7 +18,7 @@ def test_conversation_context_user_id_default_none():
def test_queue_add_stores_user_id():
q = MemoryUpdateQueue()
with patch.object(q, "_reset_timer"):
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
assert len(q._queue) == 1
assert q._queue[0].user_id == "alice"
@@ -26,7 +27,7 @@ def test_queue_add_stores_user_id():
def test_queue_process_passes_user_id_to_updater():
q = MemoryUpdateQueue()
with patch.object(q, "_reset_timer"):
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="t1", messages=["msg"], user_id="alice")
mock_updater = MagicMock()
@@ -37,3 +38,42 @@ def test_queue_process_passes_user_id_to_updater():
mock_updater.update_memory.assert_called_once()
call_kwargs = mock_updater.update_memory.call_args.kwargs
assert call_kwargs["user_id"] == "alice"
def test_queue_keeps_updates_for_different_users_in_same_thread_and_agent():
q = MemoryUpdateQueue()
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
q.add(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
assert q.pending_count == 2
assert [context.user_id for context in q._queue] == ["alice", "bob"]
assert [context.messages for context in q._queue] == [["alice update"], ["bob update"]]
def test_queue_still_coalesces_updates_for_same_user_thread_and_agent():
q = MemoryUpdateQueue()
with patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)), patch.object(q, "_reset_timer"):
q.add(thread_id="main", messages=["first"], agent_name="researcher", user_id="alice")
q.add(thread_id="main", messages=["second"], agent_name="researcher", user_id="alice")
assert q.pending_count == 1
assert q._queue[0].messages == ["second"]
assert q._queue[0].user_id == "alice"
assert q._queue[0].agent_name == "researcher"
def test_add_nowait_keeps_different_users_separate():
q = MemoryUpdateQueue()
with (
patch("deerflow.agents.memory.queue.get_memory_config", return_value=MemoryConfig(enabled=True)),
patch.object(q, "_schedule_timer"),
):
q.add_nowait(thread_id="main", messages=["alice update"], agent_name="researcher", user_id="alice")
q.add_nowait(thread_id="main", messages=["bob update"], agent_name="researcher", user_id="bob")
assert q.pending_count == 2
assert [context.user_id for context in q._queue] == ["alice", "bob"]
+33
View File
@@ -268,6 +268,39 @@ class TestEdgeCases:
class TestDbRunEventStore:
"""Tests for DbRunEventStore with temp SQLite."""
@pytest.mark.anyio
async def test_postgres_max_seq_uses_advisory_lock_without_for_update(self):
from sqlalchemy.dialects import postgresql
from deerflow.runtime.events.store.db import DbRunEventStore
class FakeSession:
def __init__(self):
self.dialect = postgresql.dialect()
self.execute_calls = []
self.scalar_stmt = None
def get_bind(self):
return self
async def execute(self, stmt, params=None):
self.execute_calls.append((stmt, params))
async def scalar(self, stmt):
self.scalar_stmt = stmt
return 41
session = FakeSession()
max_seq = await DbRunEventStore._max_seq_for_thread(session, "thread-1")
assert max_seq == 41
assert session.execute_calls
assert session.execute_calls[0][1] == {"thread_id": "thread-1"}
assert "pg_advisory_xact_lock" in str(session.execute_calls[0][0])
compiled = str(session.scalar_stmt.compile(dialect=postgresql.dialect()))
assert "FOR UPDATE" not in compiled
@pytest.mark.anyio
async def test_basic_crud(self, tmp_path):
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
+48
View File
@@ -3,7 +3,10 @@
Uses a temp SQLite DB to test ORM-backed CRUD operations.
"""
import re
import pytest
from sqlalchemy.dialects import postgresql
from deerflow.persistence.run import RunRepository
@@ -278,3 +281,48 @@ class TestRunRepository:
assert row4["model_name"] is None
await _cleanup()
@pytest.mark.anyio
async def test_aggregate_tokens_by_thread_reuses_shared_model_name_expression(self):
captured = []
class FakeResult:
def all(self):
return []
class FakeSession:
async def execute(self, stmt):
captured.append(stmt)
return FakeResult()
class FakeSessionContext:
async def __aenter__(self):
return FakeSession()
async def __aexit__(self, exc_type, exc, tb):
return None
repo = RunRepository(lambda: FakeSessionContext())
agg = await repo.aggregate_tokens_by_thread("t1")
assert agg == {
"total_tokens": 0,
"total_input_tokens": 0,
"total_output_tokens": 0,
"total_runs": 0,
"by_model": {},
"by_caller": {"lead_agent": 0, "subagent": 0, "middleware": 0},
}
assert len(captured) == 1
stmt = captured[0]
compiled_sql = str(stmt.compile(dialect=postgresql.dialect()))
select_sql, group_by_sql = compiled_sql.split(" GROUP BY ", maxsplit=1)
model_expr_pattern = r"coalesce\(runs\.model_name, %\(([^)]+)\)s\)"
select_match = re.search(model_expr_pattern + r" AS model", select_sql)
group_by_match = re.fullmatch(model_expr_pattern, group_by_sql.strip())
assert select_match is not None
assert group_by_match is not None
assert select_match.group(1) == group_by_match.group(1)
+81
View File
@@ -0,0 +1,81 @@
from __future__ import annotations
import os
from pathlib import Path
import pytest
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.types import JSON
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from store.persistence.json_compat import json_match
def _table():
metadata = MetaData()
return Table("t", metadata, Column("data", JSON), Column("id", String))
def test_storage_json_match_compiles_sqlite() -> None:
from sqlalchemy import create_engine
table = _table()
dialect = create_engine("sqlite://").dialect
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'null'")
assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'true'")
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "= 'integer'" in int_sql
assert "CAST" in int_sql
float_sql = str(json_match(table.c.data, "k", 3.14).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "IN ('integer', 'real')" in float_sql
assert "REAL" in float_sql
def test_storage_json_match_compiles_postgres() -> None:
table = _table()
dialect = postgresql.dialect()
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_typeof(t.data -> 'k') = 'null'")
assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')")
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "CASE WHEN" in int_sql
assert "BIGINT" in int_sql
assert "'^-?[0-9]+$'" in int_sql
def test_storage_json_match_compiles_mysql() -> None:
table = _table()
dialect = mysql.dialect()
null_sql = str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert null_sql == "JSON_TYPE(JSON_EXTRACT(t.data, '$.\"k\"')) = 'NULL'"
bool_sql = str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "JSON_TYPE(JSON_EXTRACT" in bool_sql
assert "= 'BOOLEAN'" in bool_sql
assert "= 'true'" in bool_sql
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "= 'INTEGER'" in int_sql
assert "SIGNED" in int_sql
def test_storage_json_match_rejects_unsafe_keys_and_values() -> None:
table = _table()
for bad_key in ["a.b", "bad;key", "with space", "", 42, None]:
with pytest.raises(ValueError, match="JsonMatch key must match"):
json_match(table.c.data, bad_key, "x") # type: ignore[arg-type]
for bad_value in [[], {}, object()]:
with pytest.raises(TypeError, match="JsonMatch value must be"):
json_match(table.c.data, "k", bad_value)
with pytest.raises(TypeError, match="out of signed 64-bit range"):
json_match(table.c.data, "k", 2**63)
@@ -0,0 +1,122 @@
from __future__ import annotations
import os
import subprocess
import sys
from pathlib import Path
from types import SimpleNamespace
import pytest
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from store.config.storage_config import StorageConfig
from store.persistence.factory import _create_database_url, storage_config_from_database_config
def test_database_sqlite_config_maps_to_storage_config(tmp_path):
database = SimpleNamespace(
backend="sqlite",
sqlite_dir=str(tmp_path),
echo_sql=True,
pool_size=9,
)
storage = storage_config_from_database_config(database)
assert storage == StorageConfig(
driver="sqlite",
sqlite_dir=str(tmp_path),
echo_sql=True,
pool_size=9,
)
assert storage.sqlite_storage_path == str(tmp_path / "deerflow.db")
def test_database_memory_config_is_not_a_storage_backend():
database = SimpleNamespace(backend="memory")
with pytest.raises(ValueError, match="Unsupported database backend"):
storage_config_from_database_config(database)
def test_database_postgres_config_preserves_url_and_pool_options():
database = SimpleNamespace(
backend="postgres",
postgres_url="postgresql://user:pass@db.example:5544/deerflow",
echo_sql=True,
pool_size=11,
)
storage = storage_config_from_database_config(database)
url = _create_database_url(storage)
assert storage.driver == "postgres"
assert storage.database_url == "postgresql://user:pass@db.example:5544/deerflow"
assert storage.username == "user"
assert storage.password == "pass"
assert storage.host == "db.example"
assert storage.port == 5544
assert storage.db_name == "deerflow"
assert storage.echo_sql is True
assert storage.pool_size == 11
assert url.drivername == "postgresql+asyncpg"
assert url.database == "deerflow"
def test_mysql_database_url_is_normalized_to_async_driver():
storage = StorageConfig(
driver="mysql",
database_url="mysql://user:pass@db.example:3306/deerflow",
)
url = _create_database_url(storage)
assert url.drivername == "mysql+aiomysql"
assert url.database == "deerflow"
def test_mysql_async_database_url_is_preserved():
storage = StorageConfig(
driver="mysql",
database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow",
)
url = _create_database_url(storage)
assert url.drivername == "mysql+asyncmy"
assert url.database == "deerflow"
def test_database_postgres_requires_url():
database = SimpleNamespace(backend="postgres", postgres_url="")
with pytest.raises(ValueError, match="database.postgres_url is required"):
storage_config_from_database_config(database)
def test_unsupported_database_backend_rejected():
database = SimpleNamespace(backend="oracle")
with pytest.raises(ValueError, match="Unsupported database backend"):
storage_config_from_database_config(database)
def test_storage_models_import_without_config_file(tmp_path):
env = os.environ.copy()
env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml")
result = subprocess.run(
[
sys.executable,
"-c",
"from store.persistence.base_model import UniversalText, id_key; from store.repositories.models import RunEvent; print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
],
check=False,
capture_output=True,
env=env,
text=True,
)
assert result.returncode == 0, result.stderr
assert "UniversalText run_events" in result.stdout
@@ -0,0 +1,58 @@
from __future__ import annotations
import asyncio
import os
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from sqlalchemy import inspect
from store.persistence import create_persistence_from_database_config
from store.repositories import UserCreate, build_user_repository
def test_sqlite_persistence_from_database_config_creates_storage_tables(tmp_path):
async def run() -> None:
persistence = await create_persistence_from_database_config(
SimpleNamespace(
backend="sqlite",
sqlite_dir=str(tmp_path),
echo_sql=False,
pool_size=5,
)
)
assert persistence is not None
try:
await persistence.setup()
async with persistence.engine.connect() as conn:
tables = await conn.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names()))
assert {
"users",
"runs",
"run_events",
"threads_meta",
"feedback",
}.issubset(tables)
async with persistence.session_factory() as session:
repo = build_user_repository(session)
user = await repo.create_user(
UserCreate(
id=str(uuid4()),
email="storage-user@example.com",
password_hash="hash",
)
)
await session.commit()
async with persistence.session_factory() as session:
repo = build_user_repository(session)
assert await repo.get_user_by_id(user.id) == user
finally:
await persistence.aclose()
asyncio.run(run())
+395
View File
@@ -0,0 +1,395 @@
from __future__ import annotations
import os
from datetime import UTC, datetime, timedelta
from pathlib import Path
from types import SimpleNamespace
import pytest
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from store.persistence import create_persistence_from_database_config
from store.repositories import (
FeedbackCreate,
InvalidMetadataFilterError,
RunCreate,
RunEventCreate,
ThreadMetaCreate,
build_feedback_repository,
build_run_event_repository,
build_run_repository,
build_thread_meta_repository,
)
async def _make_persistence(tmp_path):
persistence = await create_persistence_from_database_config(
SimpleNamespace(
backend="sqlite",
sqlite_dir=str(tmp_path),
echo_sql=False,
pool_size=5,
)
)
await persistence.setup()
return persistence
@pytest.mark.anyio
async def test_storage_run_repository_filters_and_aggregates(tmp_path):
persistence = await _make_persistence(tmp_path)
old = datetime.now(UTC) - timedelta(hours=1)
newer = datetime.now(UTC)
try:
async with persistence.session_factory() as session:
repo = build_run_repository(session)
await repo.create_run(
RunCreate(
run_id="run-old",
thread_id="thread-1",
user_id="alice",
status="pending",
model_name="model-a",
metadata={"kind": "draft"},
kwargs={"temperature": 0.2},
created_time=old,
)
)
await repo.create_run(
RunCreate(
run_id="run-new",
thread_id="thread-1",
user_id="bob",
status="running",
model_name="model-b",
error="queued",
created_time=newer,
)
)
await repo.create_run(RunCreate(run_id="run-other", thread_id="thread-2", status="running"))
await repo.update_run_completion(
"run-old",
status="success",
total_input_tokens=7,
total_output_tokens=3,
total_tokens=10,
llm_call_count=1,
lead_agent_tokens=8,
subagent_tokens=2,
first_human_message="hello",
last_ai_message="world",
)
await repo.update_run_completion(
"run-new",
status="error",
total_tokens=5,
middleware_tokens=5,
error="failed",
)
await session.commit()
async with persistence.session_factory() as session:
repo = build_run_repository(session)
fetched = await repo.get_run("run-old")
assert fetched is not None
assert fetched.metadata == {"kind": "draft"}
assert fetched.kwargs == {"temperature": 0.2}
assert fetched.first_human_message == "hello"
assert fetched.last_ai_message == "world"
all_thread_runs = await repo.list_runs_by_thread("thread-1")
assert [run.run_id for run in all_thread_runs] == ["run-new", "run-old"]
alice_runs = await repo.list_runs_by_thread("thread-1", user_id="alice")
assert [run.run_id for run in alice_runs] == ["run-old"]
pending = await repo.list_pending(before=datetime.now(UTC).isoformat())
assert [run.run_id for run in pending] == []
agg = await repo.aggregate_tokens_by_thread("thread-1")
assert agg["total_tokens"] == 15
assert agg["total_input_tokens"] == 7
assert agg["total_output_tokens"] == 3
assert agg["total_runs"] == 2
assert agg["by_model"] == {
"model-a": {"tokens": 10, "runs": 1},
"model-b": {"tokens": 5, "runs": 1},
}
assert agg["by_caller"] == {"lead_agent": 8, "subagent": 2, "middleware": 5}
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
await repo.create_thread_meta(
ThreadMetaCreate(
thread_id="thread-1",
assistant_id="agent-a",
user_id="alice",
display_name="Initial",
status="idle",
metadata={"topic": "finance", "region": "cn"},
)
)
await repo.create_thread_meta(
ThreadMetaCreate(
thread_id="thread-2",
assistant_id="agent-b",
user_id="bob",
status="running",
metadata={"topic": "legal"},
)
)
await repo.update_thread_meta(
"thread-1",
display_name="Updated",
status="running",
metadata={"topic": "finance", "region": "us"},
)
await session.commit()
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
fetched = await repo.get_thread_meta("thread-1")
assert fetched is not None
assert fetched.display_name == "Updated"
assert fetched.status == "running"
assert fetched.metadata == {"topic": "finance", "region": "us"}
by_metadata = await repo.search_threads(metadata={"topic": "finance"}, user_id="alice")
assert [thread.thread_id for thread in by_metadata] == ["thread-1"]
by_assistant = await repo.search_threads(assistant_id="agent-b")
assert [thread.thread_id for thread in by_assistant] == ["thread-2"]
await repo.delete_thread("thread-1")
await session.commit()
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
assert await repo.get_thread_meta("thread-1") is None
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filters_are_type_safe(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-true", metadata={"value": True}))
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-false", metadata={"value": False}))
await repo.create_thread_meta(ThreadMetaCreate(thread_id="int-one", metadata={"value": 1}))
await repo.create_thread_meta(ThreadMetaCreate(thread_id="null-value", metadata={"value": None}))
await repo.create_thread_meta(ThreadMetaCreate(thread_id="missing-value", metadata={"other": "x"}))
await session.commit()
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
assert [row.thread_id for row in await repo.search_threads(metadata={"value": True})] == ["bool-true"]
assert [row.thread_id for row in await repo.search_threads(metadata={"value": False})] == ["bool-false"]
assert [row.thread_id for row in await repo.search_threads(metadata={"value": 1})] == ["int-one"]
assert [row.thread_id for row in await repo.search_threads(metadata={"value": None})] == ["null-value"]
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filters_paginate_after_sql_match(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
for index in range(30):
metadata = {"target": "yes"} if index % 3 == 0 else {"target": "no"}
await repo.create_thread_meta(ThreadMetaCreate(thread_id=f"thread-{index:02d}", metadata=metadata))
await session.commit()
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
first_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=0)
second_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=3)
last_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=9)
assert len(first_page) == 3
assert len(second_page) == 3
assert len(last_page) == 1
assert {row.thread_id for row in first_page}.isdisjoint({row.thread_id for row in second_page})
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filter_rejects_invalid_entries(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-1", metadata={"env": "prod"}))
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-2", metadata={"env": "staging"}))
await session.commit()
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
partial = await repo.search_threads(metadata={"env": "prod", "bad;key": "ignored"})
assert [row.thread_id for row in partial] == ["thread-1"]
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
await repo.search_threads(metadata={"bad;key": "x"})
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
await repo.search_threads(metadata={"env": ["prod", "staging"]})
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_feedback_repository(session)
first = await repo.create_feedback(
FeedbackCreate(
feedback_id="fb-1",
run_id="run-1",
thread_id="thread-1",
rating=1,
user_id="alice",
message_id="msg-1",
comment="good",
)
)
second = await repo.create_feedback(
FeedbackCreate(
feedback_id="fb-2",
run_id="run-1",
thread_id="thread-1",
rating=-1,
user_id="bob",
)
)
await session.commit()
async with persistence.session_factory() as session:
repo = build_feedback_repository(session)
assert await repo.get_feedback(first.feedback_id) == first
assert [item.feedback_id for item in await repo.list_feedback_by_run("run-1")] == [
second.feedback_id,
first.feedback_id,
]
assert {item.feedback_id for item in await repo.list_feedback_by_thread("thread-1")} == {
"fb-1",
"fb-2",
}
assert await repo.delete_feedback("fb-1") is True
assert await repo.delete_feedback("missing") is False
with pytest.raises(ValueError, match="rating must be"):
await repo.create_feedback(
FeedbackCreate(
feedback_id="fb-bad",
run_id="run-1",
thread_id="thread-1",
rating=0,
)
)
await session.commit()
async with persistence.session_factory() as session:
repo = build_feedback_repository(session)
assert await repo.get_feedback("fb-1") is None
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_run_event_repository(session)
rows = await repo.append_batch(
[
RunEventCreate(
thread_id="thread-1",
run_id="run-1",
user_id="alice",
event_type="message",
category="message",
content={"role": "user", "content": "hello"},
metadata={"source": "input"},
),
RunEventCreate(
thread_id="thread-1",
run_id="run-1",
event_type="tool",
category="debug",
content="tool-call",
),
RunEventCreate(
thread_id="thread-1",
run_id="run-2",
event_type="message",
category="message",
content="second",
),
RunEventCreate(
thread_id="thread-2",
run_id="run-3",
event_type="message",
category="message",
content="other-thread",
),
]
)
await session.commit()
assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
assert [row.seq for row in rows] == sorted(row.seq for row in rows)
assert rows[1].seq == rows[0].seq + 1
assert rows[2].seq == rows[1].seq + 1
assert rows[0].content == {"role": "user", "content": "hello"}
assert rows[0].metadata == {"source": "input", "content_is_json": True}
async with persistence.session_factory() as session:
repo = build_run_event_repository(session)
messages = await repo.list_messages("thread-1", limit=2)
assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
assert await repo.count_messages("thread-1") == 2
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
assert [event.seq for event in after] == [rows[0].seq]
before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
assert [event.seq for event in before] == [rows[0].seq]
events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
assert [event.content for event in events] == ["tool-call"]
assert await repo.delete_by_run("thread-1", "run-1") == 2
assert await repo.delete_by_thread("thread-2") == 1
await session.commit()
async with persistence.session_factory() as session:
repo = build_run_event_repository(session)
remaining = await repo.list_events("thread-1", "run-2")
assert [event.seq for event in remaining] == [rows[2].seq]
assert await repo.count_messages("thread-2") == 0
later = await repo.append_batch(
[
RunEventCreate(
thread_id="thread-1",
run_id="run-4",
event_type="message",
category="message",
content="after-delete",
)
]
)
assert later[0].seq > rows[2].seq
finally:
await persistence.aclose()
@@ -0,0 +1,177 @@
from __future__ import annotations
import asyncio
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from uuid import uuid4
import pytest
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.repositories import UserCreate, UserNotFoundError, build_user_repository
from store.repositories.models import User as UserModel
@asynccontextmanager
async def _session_factory(tmp_path) -> AsyncGenerator[async_sessionmaker[AsyncSession]]:
db_path = tmp_path / "storage-users.db"
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
async with engine.begin() as conn:
await conn.run_sync(UserModel.metadata.create_all)
try:
yield async_sessionmaker(engine, expire_on_commit=False)
finally:
await engine.dispose()
async def _create_user(
session_factory: async_sessionmaker[AsyncSession],
*,
email: str = "user@example.com",
system_role: str = "user",
oauth_provider: str | None = None,
oauth_id: str | None = None,
):
async with session_factory() as session:
repo = build_user_repository(session)
user = await repo.create_user(
UserCreate(
id=str(uuid4()),
email=email,
password_hash="hash",
system_role=system_role, # type: ignore[arg-type]
oauth_provider=oauth_provider,
oauth_id=oauth_id,
)
)
await session.commit()
return user
def test_create_and_get_user_by_id_and_email(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
created = await _create_user(session_factory)
async with session_factory() as session:
repo = build_user_repository(session)
by_id = await repo.get_user_by_id(created.id)
by_email = await repo.get_user_by_email(created.email)
assert by_id == created
assert by_email == created
assert created.system_role == "user"
assert created.needs_setup is False
assert created.token_version == 0
asyncio.run(run())
def test_duplicate_email_raises_value_error(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
await _create_user(session_factory, email="dupe@example.com")
async with session_factory() as session:
repo = build_user_repository(session)
with pytest.raises(ValueError, match="Email already registered"):
await repo.create_user(
UserCreate(
id=str(uuid4()),
email="dupe@example.com",
password_hash="hash",
)
)
asyncio.run(run())
def test_oauth_lookup_and_plain_users_without_oauth(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
await _create_user(session_factory, email="local-1@example.com")
await _create_user(session_factory, email="local-2@example.com")
oauth_user = await _create_user(
session_factory,
email="oauth@example.com",
oauth_provider="github",
oauth_id="gh-123",
)
async with session_factory() as session:
repo = build_user_repository(session)
assert await repo.count_users() == 3
assert await repo.get_user_by_oauth("github", "gh-123") == oauth_user
assert await repo.get_user_by_oauth("github", "missing") is None
asyncio.run(run())
def test_count_admins_and_get_first_admin(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
await _create_user(session_factory, email="user@example.com")
admin = await _create_user(
session_factory,
email="admin@example.com",
system_role="admin",
)
async with session_factory() as session:
repo = build_user_repository(session)
assert await repo.count_users() == 2
assert await repo.count_admin_users() == 1
assert await repo.get_first_admin() == admin
asyncio.run(run())
def test_update_user_round_trips_token_version_and_setup_state(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
created = await _create_user(session_factory)
updated = created.model_copy(
update={
"email": "renamed@example.com",
"token_version": 4,
"needs_setup": True,
}
)
async with session_factory() as session:
repo = build_user_repository(session)
saved = await repo.update_user(updated)
await session.commit()
async with session_factory() as session:
repo = build_user_repository(session)
fetched = await repo.get_user_by_id(created.id)
assert saved.email == "renamed@example.com"
assert fetched == updated
asyncio.run(run())
def test_update_missing_user_raises(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
missing = UserCreate(id=str(uuid4()), email="missing@example.com")
async with session_factory() as session:
repo = build_user_repository(session)
created_shape = await repo.create_user(missing)
await session.rollback()
with pytest.raises(UserNotFoundError):
await repo.update_user(created_shape)
asyncio.run(run())
+26 -1
View File
@@ -30,12 +30,18 @@ def _dynamic_context_reminder(msg_id: str = "reminder-1") -> HumanMessage:
)
def _runtime(thread_id: str | None = "thread-1", agent_name: str | None = None) -> SimpleNamespace:
def _runtime(
thread_id: str | None = "thread-1",
agent_name: str | None = None,
user_id: str | None = None,
) -> SimpleNamespace:
context = {}
if thread_id is not None:
context["thread_id"] = thread_id
if agent_name is not None:
context["agent_name"] = agent_name
if user_id is not None:
context["user_id"] = user_id
return SimpleNamespace(context=context)
@@ -634,3 +640,22 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_queue", lambda: queue)
memory_flush_hook(
SummarizationEvent(
messages_to_summarize=tuple(_messages()[:2]),
preserved_messages=(),
thread_id="main",
agent_name="researcher",
runtime=_runtime(thread_id="main", agent_name="researcher", user_id="alice"),
)
)
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
+154
View File
@@ -59,12 +59,15 @@ def _make_result(
ai_messages: list[dict] | None = None,
result: str | None = None,
error: str | None = None,
token_usage_records: list[dict] | None = None,
) -> SimpleNamespace:
return SimpleNamespace(
status=status,
ai_messages=ai_messages or [],
result=result,
error=error,
token_usage_records=token_usage_records or [],
usage_reported=False,
)
@@ -1132,3 +1135,154 @@ def test_cancellation_reports_subagent_usage(monkeypatch):
assert len(report_calls) == 1
assert report_calls[0][1] is cancel_result
assert cleanup_calls == ["tc-cancel-report"]
@pytest.mark.parametrize(
"status, expected_type",
[
(FakeSubagentStatus.COMPLETED, "task_completed"),
(FakeSubagentStatus.FAILED, "task_failed"),
(FakeSubagentStatus.CANCELLED, "task_cancelled"),
(FakeSubagentStatus.TIMED_OUT, "task_timed_out"),
],
)
def test_terminal_events_include_usage(monkeypatch, status, expected_type):
"""Terminal task events include a usage summary from token_usage_records."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
records = [
{"source_run_id": "r1", "caller": "subagent:general-purpose", "input_tokens": 100, "output_tokens": 50, "total_tokens": 150},
{"source_run_id": "r2", "caller": "subagent:general-purpose", "input_tokens": 200, "output_tokens": 80, "total_tokens": 280},
]
result = _make_result(status, result="ok" if status == FakeSubagentStatus.COMPLETED else None, error="err" if status != FakeSubagentStatus.COMPLETED else None, token_usage_records=records)
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-usage",
)
terminal_events = [e for e in events if e["type"] == expected_type]
assert len(terminal_events) == 1
assert terminal_events[0]["usage"] == {
"input_tokens": 300,
"output_tokens": 130,
"total_tokens": 430,
}
def test_terminal_event_usage_none_when_no_records(monkeypatch):
"""Terminal event has usage=None when token_usage_records is empty."""
config = _make_subagent_config()
runtime = _make_runtime()
events = []
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=[])
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _: config)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: events.append)
monkeypatch.setattr(task_tool_module.asyncio, "sleep", _no_sleep)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-no-records",
)
completed = [e for e in events if e["type"] == "task_completed"]
assert len(completed) == 1
assert completed[0]["usage"] is None
@pytest.mark.parametrize("error", [FileNotFoundError("missing config"), ValueError("invalid config")])
def test_subagent_usage_cache_is_skipped_when_default_config_cannot_load(monkeypatch, error):
monkeypatch.setattr(
task_tool_module,
"get_app_config",
MagicMock(side_effect=error),
)
assert task_tool_module._token_usage_cache_enabled(None) is False
def test_subagent_usage_cache_is_skipped_when_token_usage_is_disabled(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=False))
runtime = _make_runtime(app_config=app_config)
records = [{"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}]
result = _make_result(FakeSubagentStatus.COMPLETED, result="done", token_usage_records=records)
task_tool_module._subagent_usage_cache.clear()
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", lambda _: result)
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr(task_tool_module, "_report_subagent_usage", lambda *_: None)
monkeypatch.setattr(task_tool_module, "cleanup_background_task", lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-disabled-cache",
)
assert task_tool_module.pop_cached_subagent_usage("tc-disabled-cache") is None
def test_subagent_usage_cache_is_cleared_when_polling_raises(monkeypatch):
config = _make_subagent_config()
app_config = SimpleNamespace(token_usage=SimpleNamespace(enabled=True))
runtime = _make_runtime(app_config=app_config)
task_tool_module._subagent_usage_cache["tc-error"] = {"input_tokens": 1, "output_tokens": 1, "total_tokens": 2}
monkeypatch.setattr(task_tool_module, "SubagentStatus", FakeSubagentStatus)
monkeypatch.setattr(task_tool_module, "get_available_subagent_names", lambda *, app_config: ["general-purpose"])
monkeypatch.setattr(task_tool_module, "get_subagent_config", lambda _, *, app_config: config)
monkeypatch.setattr(
task_tool_module,
"SubagentExecutor",
type("DummyExecutor", (), {"__init__": lambda self, **kwargs: None, "execute_async": lambda self, prompt, task_id=None: task_id}),
)
monkeypatch.setattr(task_tool_module, "get_background_task_result", MagicMock(side_effect=RuntimeError("poll failed")))
monkeypatch.setattr(task_tool_module, "get_stream_writer", lambda: lambda _: None)
monkeypatch.setattr("deerflow.tools.get_available_tools", MagicMock(return_value=[]))
with pytest.raises(RuntimeError, match="poll failed"):
_run_task_tool(
runtime=runtime,
description="test",
prompt="do work",
subagent_type="general-purpose",
tool_call_id="tc-error",
)
assert task_tool_module.pop_cached_subagent_usage("tc-error") is None
+48 -1
View File
@@ -1,9 +1,10 @@
"""Tests for TokenUsageMiddleware attribution annotations."""
import importlib
import logging
from unittest.mock import MagicMock
from langchain_core.messages import AIMessage
from langchain_core.messages import AIMessage, ToolMessage
from deerflow.agents.middlewares.token_usage_middleware import (
TOKEN_USAGE_ATTRIBUTION_KEY,
@@ -232,3 +233,49 @@ class TestTokenUsageMiddleware:
"tool_call_id": "write_todos:remove",
}
]
def test_merges_subagent_usage_by_message_position_when_ai_message_ids_are_missing(self, monkeypatch):
middleware = TokenUsageMiddleware()
first_dispatch = AIMessage(
content="",
tool_calls=[{"id": "task:first", "name": "task", "args": {}}],
)
second_dispatch = AIMessage(
content="",
tool_calls=[
{"id": "task:second-a", "name": "task", "args": {}},
{"id": "task:second-b", "name": "task", "args": {}},
],
)
messages = [
first_dispatch,
ToolMessage(content="first", tool_call_id="task:first"),
second_dispatch,
ToolMessage(content="second-a", tool_call_id="task:second-a"),
ToolMessage(content="second-b", tool_call_id="task:second-b"),
AIMessage(content="done"),
]
cached_usage = {
"task:second-a": {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
"task:second-b": {"input_tokens": 20, "output_tokens": 7, "total_tokens": 27},
}
task_tool_module = importlib.import_module("deerflow.tools.builtins.task_tool")
monkeypatch.setattr(
task_tool_module,
"pop_cached_subagent_usage",
lambda tool_call_id: cached_usage.pop(tool_call_id, None),
)
result = middleware.after_model({"messages": messages}, _make_runtime())
assert result is not None
usage_updates = [message for message in result["messages"] if getattr(message, "usage_metadata", None)]
assert len(usage_updates) == 1
updated = usage_updates[0]
assert updated.tool_calls == second_dispatch.tool_calls
assert updated.usage_metadata == {
"input_tokens": 30,
"output_tokens": 12,
"total_tokens": 42,
}
+4 -8
View File
@@ -65,8 +65,7 @@ def _make_minimal_config(tools):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash, mock_cfg):
def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_bash, mock_cfg):
"""Config-loaded async-only tools can still be invoked by sync clients."""
async def async_tool_impl(x: int) -> str:
@@ -98,8 +97,7 @@ def test_config_loaded_async_only_tool_gets_sync_wrapper(mock_reset, mock_bash,
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
def test_no_duplicates_returned(mock_bash, mock_cfg):
"""get_available_tools() never returns two tools with the same name."""
mock_cfg.return_value = _make_minimal_config([])
@@ -113,8 +111,7 @@ def test_no_duplicates_returned(mock_reset, mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
def test_first_occurrence_wins(mock_bash, mock_cfg):
"""When duplicates exist, the first occurrence is kept."""
mock_cfg.return_value = _make_minimal_config([])
@@ -132,8 +129,7 @@ def test_first_occurrence_wins(mock_reset, mock_bash, mock_cfg):
@patch("deerflow.tools.tools.get_app_config")
@patch("deerflow.tools.tools.is_host_bash_allowed", return_value=True)
@patch("deerflow.tools.tools.reset_deferred_registry")
def test_duplicate_triggers_warning(mock_reset, mock_bash, mock_cfg, caplog):
def test_duplicate_triggers_warning(mock_bash, mock_cfg, caplog):
"""A warning is logged for every skipped duplicate."""
import logging
+96 -4
View File
@@ -14,6 +14,7 @@ resolution-markers = [
members = [
"deer-flow",
"deerflow-harness",
"deerflow-storage",
]
[[package]]
@@ -136,6 +137,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441, upload-time = "2026-03-31T22:00:12.791Z" },
]
[[package]]
name = "aiomysql"
version = "0.3.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pymysql" },
]
sdist = { url = "https://files.pythonhosted.org/packages/29/e0/302aeffe8d90853556f47f3106b89c16cc2ec2a4d269bdfd82e3f4ae12cc/aiomysql-0.3.2.tar.gz", hash = "sha256:72d15ef5cfc34c03468eb41e1b90adb9fd9347b0b589114bd23ead569a02ac1a", size = 108311, upload-time = "2025-10-22T00:15:21.278Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4c/af/aae0153c3e28712adaf462328f6c7a3c196a1c1c27b491de4377dd3e6b52/aiomysql-0.3.2-py3-none-any.whl", hash = "sha256:c82c5ba04137d7afd5c693a258bea8ead2aad77101668044143a991e04632eb2", size = 71834, upload-time = "2025-10-22T00:15:15.905Z" },
]
[[package]]
name = "aiosignal"
version = "1.4.0"
@@ -746,6 +759,7 @@ source = { virtual = "." }
dependencies = [
{ name = "bcrypt" },
{ name = "deerflow-harness" },
{ name = "deerflow-storage" },
{ name = "dingtalk-stream" },
{ name = "email-validator" },
{ name = "fastapi" },
@@ -763,8 +777,12 @@ dependencies = [
]
[package.optional-dependencies]
mysql = [
{ name = "deerflow-storage", extra = ["mysql"] },
]
postgres = [
{ name = "deerflow-harness", extra = ["postgres"] },
{ name = "deerflow-storage", extra = ["postgres"] },
]
[package.dev-dependencies]
@@ -780,6 +798,9 @@ requires-dist = [
{ name = "bcrypt", specifier = ">=4.0.0" },
{ name = "deerflow-harness", editable = "packages/harness" },
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
{ name = "deerflow-storage", editable = "packages/storage" },
{ name = "deerflow-storage", extras = ["mysql"], marker = "extra == 'mysql'", editable = "packages/storage" },
{ name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
{ name = "email-validator", specifier = ">=2.0.0" },
{ name = "fastapi", specifier = ">=0.115.0" },
@@ -795,7 +816,7 @@ requires-dist = [
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
]
provides-extras = ["postgres"]
provides-extras = ["postgres", "mysql"]
[package.metadata.requires-dev]
dev = [
@@ -901,6 +922,54 @@ requires-dist = [
]
provides-extras = ["ollama", "postgres", "pymupdf"]
[[package]]
name = "deerflow-storage"
version = "0.1.0"
source = { editable = "packages/storage" }
dependencies = [
{ name = "alembic" },
{ name = "dotenv" },
{ name = "langgraph" },
{ name = "pydantic" },
{ name = "pyyaml" },
{ name = "sqlalchemy", extra = ["asyncio"] },
]
[package.optional-dependencies]
mysql = [
{ name = "aiomysql" },
{ name = "langgraph-checkpoint-mysql" },
]
postgres = [
{ name = "asyncpg" },
{ name = "langgraph-checkpoint-postgres" },
{ name = "psycopg", extra = ["binary"] },
{ name = "psycopg-pool" },
]
sqlite = [
{ name = "aiosqlite" },
{ name = "langgraph-checkpoint-sqlite" },
]
[package.metadata]
requires-dist = [
{ name = "aiomysql", marker = "extra == 'mysql'", specifier = ">=0.2" },
{ name = "aiosqlite", marker = "extra == 'sqlite'", specifier = ">=0.22.1" },
{ name = "alembic", specifier = ">=1.13" },
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
{ name = "dotenv", specifier = ">=0.9.9" },
{ name = "langgraph", specifier = ">=1.1.9" },
{ name = "langgraph-checkpoint-mysql", marker = "extra == 'mysql'", specifier = ">=3.0.0" },
{ name = "langgraph-checkpoint-postgres", marker = "extra == 'postgres'", specifier = ">=3.0.5" },
{ name = "langgraph-checkpoint-sqlite", marker = "extra == 'sqlite'", specifier = ">=3.0.3" },
{ name = "psycopg", extras = ["binary"], marker = "extra == 'postgres'", specifier = ">=3.3.3" },
{ name = "psycopg-pool", marker = "extra == 'postgres'", specifier = ">=3.3.0" },
{ name = "pydantic", specifier = ">=2.12.5" },
{ name = "pyyaml", specifier = ">=6.0.3" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0,<3.0" },
]
provides-extras = ["postgres", "mysql", "sqlite"]
[[package]]
name = "defusedxml"
version = "0.7.1"
@@ -1914,6 +1983,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b9/5a/6dba29dd89b0a46ae21c707da0f9d17e94f27d3e481ed15bc99d6bd20aa6/langgraph_checkpoint-4.0.2-py3-none-any.whl", hash = "sha256:59b0f29216128a629c58dd07c98aa004f82f51805d5573126ffb419b753ff253", size = 51000, upload-time = "2026-04-15T21:02:59.096Z" },
]
[[package]]
name = "langgraph-checkpoint-mysql"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langgraph-checkpoint" },
{ name = "orjson" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e0/4e/0a6c78e5d3f2ca1525903c2363e721873594b6b77dd83537a6369193c474/langgraph_checkpoint_mysql-3.0.0.tar.gz", hash = "sha256:006aaa089f4c2fbd7b2c113b800ccd3dbb95f92203e656451677256b4b4f880f", size = 213142, upload-time = "2026-01-23T11:11:15.74Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/08/68/343103a7fae05523f9cecabbec2babdb737e66b4bf6ea48ae00c685ed11c/langgraph_checkpoint_mysql-3.0.0-py3-none-any.whl", hash = "sha256:7560ccd16e7596a047e15a307cec12dbd88fdcaab45a75759e5c6adef22a27d1", size = 38009, upload-time = "2026-01-23T11:11:14.697Z" },
]
[[package]]
name = "langgraph-checkpoint-postgres"
version = "3.0.5"
@@ -2005,7 +2088,7 @@ wheels = [
[[package]]
name = "langsmith"
version = "0.7.36"
version = "0.8.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "httpx" },
@@ -2018,9 +2101,9 @@ dependencies = [
{ name = "xxhash" },
{ name = "zstandard" },
]
sdist = { url = "https://files.pythonhosted.org/packages/8d/4c/5f20508000ee0559bfa713b85c431b1cdc95d2913247ff9eb318e7fdff7b/langsmith-0.7.36.tar.gz", hash = "sha256:d18ef34819e0a252cf52c74ce6e9bd5de6deea4f85a3aef50abc9f48d8c5f8b8", size = 4402322, upload-time = "2026-04-24T16:58:06.681Z" }
sdist = { url = "https://files.pythonhosted.org/packages/a8/64/95f1f013531395f4e8ed73caeee780f65c7c58fe028cb543f8937b45611b/langsmith-0.8.0.tar.gz", hash = "sha256:59fe5b2a56bbbe14a08aa76691f84b49e8675dd21e11b57d80c6db8c08bac2e3", size = 4432996, upload-time = "2026-04-30T22:13:07.341Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f3/8d/3ca31ae3a4a437191243ad6d9061ede9367440bb7dc9a0da1ecc2c2a4865/langsmith-0.7.36-py3-none-any.whl", hash = "sha256:e1657a795f3f1982bb8d34c98b143b630ca3eee9de2c10e670c9105233b54654", size = 381808, upload-time = "2026-04-24T16:58:04.572Z" },
{ url = "https://files.pythonhosted.org/packages/f3/e1/a4be2e696c9473bb53298df398237da5674704d781d4b748ed35aeef592a/langsmith-0.8.0-py3-none-any.whl", hash = "sha256:12cc4bc5622b835a6d841964d6034df3617bdb912dae0c1381fd0a68a9b3a3ef", size = 393268, upload-time = "2026-04-30T22:13:05.56Z" },
]
[package.optional-dependencies]
@@ -3442,6 +3525,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e6/38/84bf29f4dd72e6c450546df6ca8f53021f764fd945ba67dcc235d39bc20e/pymupdf4llm-1.27.2.3-py3-none-any.whl", hash = "sha256:bd724b79fa3f06a5b28d7a65f7acfa8de56e04bdb603ac2d6dff315e0d151aaa", size = 77348, upload-time = "2026-04-24T14:11:04.305Z" },
]
[[package]]
name = "pymysql"
version = "1.1.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/7f/ec/8d45c920e90445f0b75c590b32851853ed319763b0d8dff8d283052da8cf/pymysql-1.1.3.tar.gz", hash = "sha256:e70ebf2047a4edf6138cf79c68ad418ef620af65900aa585c5e8bfc95044d43a", size = 48207, upload-time = "2026-05-01T09:09:54.532Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/dc/9085f3d6f497e9b25fb40d6e8ecef3ddbb5cf977a949b933624a299f5c16/pymysql-1.1.3-py3-none-any.whl", hash = "sha256:8164ba62c552f6105f3b11753352d0f16b90d1703ba67d81923d5a8a5d1c5289", size = 45356, upload-time = "2026-05-01T09:09:53.316Z" },
]
[[package]]
name = "pypdfium2"
version = "5.7.1"
+3 -4
View File
@@ -10,7 +10,6 @@ import { FlickeringGrid } from "@/components/ui/flickering-grid";
import { Input } from "@/components/ui/input";
import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
/**
* Validate next parameter
@@ -72,7 +71,7 @@ export default function LoginPage() {
useEffect(() => {
let cancelled = false;
void fetch(`${getBackendBaseURL()}/api/v1/auth/setup-status`)
void fetch("/api/v1/auth/setup-status")
.then((r) => r.json())
.then((data: { needs_setup?: boolean }) => {
if (!cancelled && data.needs_setup) {
@@ -95,8 +94,8 @@ export default function LoginPage() {
try {
const endpoint = isLogin
? `${getBackendBaseURL()}/api/v1/auth/login/local`
: `${getBackendBaseURL()}/api/v1/auth/register`;
? "/api/v1/auth/login/local"
: "/api/v1/auth/register";
const body = isLogin
? `username=${encodeURIComponent(email)}&password=${encodeURIComponent(password)}`
: JSON.stringify({ email, password });
+14 -18
View File
@@ -10,7 +10,6 @@ import { Input } from "@/components/ui/input";
import { getCsrfHeaders } from "@/core/api/fetcher";
import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
type SetupMode = "loading" | "init_admin" | "change_password";
@@ -37,7 +36,7 @@ export default function SetupPage() {
setMode("change_password");
} else if (!isAuthenticated) {
// Check if the system has no users yet
void fetch(`${getBackendBaseURL()}/api/v1/auth/setup-status`)
void fetch("/api/v1/auth/setup-status")
.then((r) => r.json())
.then((data: { needs_setup?: boolean }) => {
if (cancelled) return;
@@ -73,7 +72,7 @@ export default function SetupPage() {
setLoading(true);
try {
const res = await fetch(`${getBackendBaseURL()}/api/v1/auth/initialize`, {
const res = await fetch("/api/v1/auth/initialize", {
method: "POST",
headers: { "Content-Type": "application/json" },
credentials: "include",
@@ -114,22 +113,19 @@ export default function SetupPage() {
setLoading(true);
try {
const res = await fetch(
`${getBackendBaseURL()}/api/v1/auth/change-password`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
credentials: "include",
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
new_email: email || undefined,
}),
const res = await fetch("/api/v1/auth/change-password", {
method: "POST",
headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
);
credentials: "include",
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
new_email: email || undefined,
}),
});
if (!res.ok) {
const data = await res.json();
+1 -2
View File
@@ -4,7 +4,6 @@ import { redirect } from "next/navigation";
import { AuthProvider } from "@/core/auth/AuthProvider";
import { getServerSideUser } from "@/core/auth/server";
import { assertNever } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
import { WorkspaceContent } from "./workspace-content";
@@ -45,7 +44,7 @@ export default async function WorkspaceLayout({
Retry
</Link>
<Link
href={`${getBackendBaseURL()}/api/v1/auth/logout`}
href="/api/v1/auth/logout"
className="text-muted-foreground hover:bg-muted rounded-md border px-4 py-2 text-sm"
>
Logout &amp; Reset
@@ -12,13 +12,11 @@ function TokenUsageSummary({
inputTokens,
outputTokens,
totalTokens,
unavailable = false,
}: {
className?: string;
inputTokens?: number;
outputTokens?: number;
totalTokens?: number;
unavailable?: boolean;
}) {
const { t } = useI18n();
@@ -33,21 +31,15 @@ function TokenUsageSummary({
<CoinsIcon className="size-3" />
{t.tokenUsage.label}
</span>
{!unavailable ? (
<>
<span>
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
</span>
<span>
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
</span>
<span className="font-medium">
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
</span>
</>
) : (
<span>{t.tokenUsage.unavailableShort}</span>
)}
<span>
{t.tokenUsage.input}: {formatTokenCount(inputTokens ?? 0)}
</span>
<span>
{t.tokenUsage.output}: {formatTokenCount(outputTokens ?? 0)}
</span>
<span className="font-medium">
{t.tokenUsage.total}: {formatTokenCount(totalTokens ?? 0)}
</span>
</div>
);
}
@@ -55,7 +47,7 @@ function TokenUsageSummary({
export function MessageTokenUsageList({
className,
enabled = false,
isLoading = false,
isLoading: _isLoading = false,
messages,
}: {
className?: string;
@@ -63,7 +55,7 @@ export function MessageTokenUsageList({
isLoading?: boolean;
messages: Message[];
}) {
if (!enabled || isLoading) {
if (!enabled) {
return null;
}
@@ -75,13 +67,16 @@ export function MessageTokenUsageList({
const usage = accumulateUsage(aiMessages);
if (!usage) {
return null;
}
return (
<TokenUsageSummary
className={className}
inputTokens={usage?.inputTokens}
outputTokens={usage?.outputTokens}
totalTokens={usage?.totalTokens}
unavailable={!usage}
inputTokens={usage.inputTokens}
outputTokens={usage.outputTokens}
totalTokens={usage.totalTokens}
/>
);
}
@@ -8,7 +8,6 @@ import { Input } from "@/components/ui/input";
import { fetch, getCsrfHeaders } from "@/core/api/fetcher";
import { useAuth } from "@/core/auth/AuthProvider";
import { parseAuthError } from "@/core/auth/types";
import { getBackendBaseURL } from "@/core/config";
import { useI18n } from "@/core/i18n/hooks";
import { SettingsSection } from "./settings-section";
@@ -39,20 +38,17 @@ export function AccountSettingsPage() {
setLoading(true);
try {
const res = await fetch(
`${getBackendBaseURL()}/api/v1/auth/change-password`,
{
method: "POST",
headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
}),
const res = await fetch("/api/v1/auth/change-password", {
method: "POST",
headers: {
"Content-Type": "application/json",
...getCsrfHeaders(),
},
);
body: JSON.stringify({
current_password: currentPassword,
new_password: newPassword,
}),
});
if (!res.ok) {
const data = await res.json();
+2 -4
View File
@@ -10,8 +10,6 @@ import React, {
type ReactNode,
} from "react";
import { getBackendBaseURL } from "@/core/config";
import { type User, buildLoginUrl } from "./types";
// Re-export for consumers
@@ -58,7 +56,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
const refreshUser = useCallback(async () => {
try {
setIsLoading(true);
const res = await fetch(`${getBackendBaseURL()}/api/v1/auth/me`, {
const res = await fetch("/api/v1/auth/me", {
credentials: "include",
});
@@ -90,7 +88,7 @@ export function AuthProvider({ children, initialUser }: AuthProviderProps) {
setUser(null);
try {
await fetch(`${getBackendBaseURL()}/api/v1/auth/logout`, {
await fetch("/api/v1/auth/logout", {
method: "POST",
credentials: "include",
});
+2 -2
View File
@@ -65,7 +65,7 @@ export function accumulateUsage(messages: Message[]): TokenUsage | null {
return hasUsage ? cumulative : null;
}
function hasNonZeroUsage(
export function hasNonZeroUsage(
usage: TokenUsage | null | undefined,
): usage is TokenUsage {
return (
@@ -75,7 +75,7 @@ function hasNonZeroUsage(
);
}
function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
export function addUsage(base: TokenUsage, delta: TokenUsage): TokenUsage {
return {
inputTokens: base.inputTokens + delta.inputTokens,
outputTokens: base.outputTokens + delta.outputTokens,
+15 -3
View File
@@ -296,7 +296,11 @@ export function useThreadStream({
onError(error) {
setOptimisticMessages([]);
toast.error(getStreamErrorMessage(error));
pendingUsageBaselineMessageIdsRef.current = new Set();
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
queryKey: threadTokenUsageQueryKey(threadIdRef.current),
@@ -305,7 +309,11 @@ export function useThreadStream({
},
onFinish(state) {
listeners.current.onFinish?.(state.values);
pendingUsageBaselineMessageIdsRef.current = new Set();
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
void queryClient.invalidateQueries({ queryKey: ["threads", "search"] });
if (threadIdRef.current && !isMock) {
void queryClient.invalidateQueries({
@@ -339,7 +347,11 @@ export function useThreadStream({
useEffect(() => {
startedRef.current = false;
sendInFlightRef.current = false;
pendingUsageBaselineMessageIdsRef.current = new Set();
pendingUsageBaselineMessageIdsRef.current = new Set(
messagesRef.current
.map(messageIdentity)
.filter((id): id is string => Boolean(id)),
);
prevHumanMsgCountRef.current =
latestMessageCountsRef.current.humanMessageCount;
}, [threadId]);