Compare commits

..

10 Commits

82 changed files with 4651 additions and 1790 deletions
+11 -291
View File
@@ -3,10 +3,8 @@
from __future__ import annotations
import asyncio
import json
import logging
import threading
from pathlib import Path
from typing import Any
from app.channels.base import Channel
@@ -23,12 +21,6 @@ class DiscordChannel(Channel):
Configuration keys (in ``config.yaml`` under ``channels.discord``):
- ``bot_token``: Discord Bot token.
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
- ``mention_only``: (optional) If true, only respond when the bot is mentioned.
- ``allowed_channels``: (optional) List of channel IDs where messages are always accepted
(even when mention_only is true). Use for channels where you want the bot to respond
without mentions. Empty = mention_only applies everywhere.
- ``thread_mode``: (optional) If true, group a channel conversation into a thread.
Default: same as ``mention_only``.
"""
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
@@ -40,29 +32,6 @@ class DiscordChannel(Channel):
self._allowed_guilds.add(int(guild_id))
except (TypeError, ValueError):
continue
self._mention_only: bool = bool(config.get("mention_only", False))
self._thread_mode: bool = config.get("thread_mode", self._mention_only)
self._allowed_channels: set[str] = set()
for channel_id in config.get("allowed_channels", []):
self._allowed_channels.add(str(channel_id))
# Session tracking: channel_id -> Discord thread_id (in-memory, persisted to JSON).
# Uses a dedicated JSON file separate from ChannelStore, which maps IM
# conversations to DeerFlow thread IDs — a different concern.
self._active_threads: dict[str, str] = {}
# Reverse-lookup set for O(1) thread ID checks (avoids O(n) scan of _active_threads.values()).
self._active_thread_ids: set[str] = set()
# Lock protecting _active_threads and the JSON file from concurrent access.
# _run_client (Discord loop thread) and the main thread both read/write.
self._thread_store_lock = threading.Lock()
store = config.get("channel_store")
if store is not None:
self._thread_store_path = store._path.parent / "discord_threads.json"
else:
self._thread_store_path = Path.home() / ".deer-flow" / "channels" / "discord_threads.json"
# Typing indicator management
self._typing_tasks: dict[str, asyncio.Task] = {}
self._client = None
self._thread: threading.Thread | None = None
@@ -106,56 +75,12 @@ class DiscordChannel(Channel):
self._thread = threading.Thread(target=self._run_client, daemon=True)
self._thread.start()
self._load_active_threads()
logger.info("Discord channel started")
def _load_active_threads(self) -> None:
"""Restore Discord thread mappings from the dedicated JSON file on startup."""
with self._thread_store_lock:
try:
if not self._thread_store_path.exists():
logger.debug("[Discord] no thread mappings file at %s", self._thread_store_path)
return
data = json.loads(self._thread_store_path.read_text())
self._active_threads.clear()
self._active_thread_ids.clear()
for channel_id, thread_id in data.items():
self._active_threads[channel_id] = thread_id
self._active_thread_ids.add(thread_id)
if self._active_threads:
logger.info("[Discord] restored %d thread mappings from %s", len(self._active_threads), self._thread_store_path)
except Exception:
logger.exception("[Discord] failed to load thread mappings")
def _save_thread(self, channel_id: str, thread_id: str) -> None:
"""Persist a Discord thread mapping to the dedicated JSON file."""
with self._thread_store_lock:
try:
data: dict[str, str] = {}
if self._thread_store_path.exists():
data = json.loads(self._thread_store_path.read_text())
old_id = data.get(channel_id)
data[channel_id] = thread_id
# Update reverse-lookup set
if old_id:
self._active_thread_ids.discard(old_id)
self._active_thread_ids.add(thread_id)
self._thread_store_path.parent.mkdir(parents=True, exist_ok=True)
self._thread_store_path.write_text(json.dumps(data, indent=2))
except Exception:
logger.exception("[Discord] failed to save thread mapping for channel %s", channel_id)
async def stop(self) -> None:
self._running = False
self.bus.unsubscribe_outbound(self._on_outbound)
# Cancel all active typing indicator tasks
for target_id, task in list(self._typing_tasks.items()):
if not task.done():
task.cancel()
logger.debug("[Discord] cancelled typing task for target %s", target_id)
self._typing_tasks.clear()
if self._client and self._discord_loop and self._discord_loop.is_running():
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
try:
@@ -175,10 +100,6 @@ class DiscordChannel(Channel):
logger.info("Discord channel stopped")
async def send(self, msg: OutboundMessage) -> None:
# Stop typing indicator once we're sending the response
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
await asyncio.wrap_future(stop_future)
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
@@ -190,9 +111,6 @@ class DiscordChannel(Channel):
await asyncio.wrap_future(send_future)
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
await asyncio.wrap_future(stop_future)
target = await self._resolve_target(msg)
if target is None:
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
@@ -212,41 +130,6 @@ class DiscordChannel(Channel):
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
return False
async def _start_typing(self, channel, chat_id: str, thread_ts: str | None = None) -> None:
"""Starts a loop to send periodic typing indicators."""
target_id = thread_ts or chat_id
if target_id in self._typing_tasks:
return # Already typing for this target
async def _typing_loop():
try:
while True:
try:
await channel.trigger_typing()
except Exception:
pass
await asyncio.sleep(10)
except asyncio.CancelledError:
pass
task = asyncio.create_task(_typing_loop())
self._typing_tasks[target_id] = task
async def _stop_typing(self, chat_id: str, thread_ts: str | None = None) -> None:
"""Stops the typing loop for a specific target."""
target_id = thread_ts or chat_id
task = self._typing_tasks.pop(target_id, None)
if task and not task.done():
task.cancel()
logger.debug("[Discord] stopped typing indicator for target %s", target_id)
async def _add_reaction(self, message) -> None:
"""Add a checkmark reaction to acknowledge the message was received."""
try:
await message.add_reaction("")
except Exception:
logger.debug("[Discord] failed to add reaction to message %s", message.id, exc_info=True)
async def _on_message(self, message) -> None:
if not self._running or not self._client:
return
@@ -269,143 +152,15 @@ class DiscordChannel(Channel):
if self._discord_module is None:
return
# Determine whether the bot is mentioned in this message
user = self._client.user if self._client else None
if user:
bot_mention = user.mention # <@ID>
alt_mention = f"<@!{user.id}>" # <@!ID> (ping variant)
standard_mention = f"<@{user.id}>"
else:
bot_mention = None
alt_mention = None
standard_mention = ""
has_mention = (bot_mention and bot_mention in message.content) or (alt_mention and alt_mention in message.content) or (standard_mention and standard_mention in message.content)
# Strip mention from text for processing
if has_mention:
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
# Don't return early if text is empty — still process the mention (e.g., create thread)
# --- Determine thread/channel routing and typing target ---
thread_id = None
chat_id = None
typing_target = None # The Discord object to type into
if isinstance(message.channel, self._discord_module.Thread):
# --- Message already inside a thread ---
thread_obj = message.channel
thread_id = str(thread_obj.id)
chat_id = str(thread_obj.parent_id or thread_obj.id)
typing_target = thread_obj
# If this is a known active thread, process normally
if thread_id in self._active_thread_ids:
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
chat_id=chat_id,
user_id=str(message.author.id),
text=text,
msg_type=msg_type,
thread_ts=thread_id,
metadata={
"guild_id": str(guild.id) if guild else None,
"channel_id": str(message.channel.id),
"message_id": str(message.id),
},
)
inbound.topic_id = thread_id
self._publish(inbound)
# Start typing indicator in the thread
if typing_target:
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
asyncio.create_task(self._add_reaction(message))
return
# Thread not tracked (orphaned) — create new thread and handle below
logger.debug("[Discord] message in orphaned thread %s, will create new thread", thread_id)
thread_id = None
typing_target = None
# At this point we're guaranteed to be in a channel, not a thread
# (the Thread case is handled above). Apply mention_only for all
# non-thread messages — no special case needed.
channel_id = str(message.channel.id)
# Check if there's an active thread for this channel
if channel_id in self._active_threads:
# respect mention_only: if enabled, only process messages that mention the bot
# (unless the channel is in allowed_channels)
# Messages within a thread are always allowed through (continuation).
# At this code point we know the message is in a channel, not a thread
# (Thread case handled above), so always apply the check.
if self._mention_only and not has_mention and channel_id not in self._allowed_channels:
logger.debug("[Discord] skipping no-@ message in channel %s (not in thread)", channel_id)
return
# mention_only + fresh @ → create new thread instead of routing to existing one
if self._mention_only and has_mention:
thread_obj = await self._create_thread(message)
if thread_obj is not None:
target_thread_id = str(thread_obj.id)
self._active_threads[channel_id] = target_thread_id
self._save_thread(channel_id, target_thread_id)
thread_id = target_thread_id
chat_id = channel_id
typing_target = thread_obj
logger.info("[Discord] created new thread %s in channel %s on mention (replacing existing thread)", target_thread_id, channel_id)
else:
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
thread_id = channel_id
chat_id = channel_id
typing_target = message.channel
else:
# Existing session → route to the existing thread
target_thread_id = self._active_threads[channel_id]
logger.debug("[Discord] routing message in channel %s to existing thread %s", channel_id, target_thread_id)
thread_id = target_thread_id
chat_id = channel_id
typing_target = await self._get_channel_or_thread(target_thread_id)
elif self._mention_only and not has_mention and channel_id not in self._allowed_channels:
# Not mentioned and not in an allowed channel → skip
logger.debug("[Discord] skipping message without mention in channel %s", channel_id)
return
elif self._mention_only and has_mention:
# First mention in this channel → create thread
thread_obj = await self._create_thread(message)
if thread_obj is not None:
target_thread_id = str(thread_obj.id)
self._active_threads[channel_id] = target_thread_id
self._save_thread(channel_id, target_thread_id)
thread_id = target_thread_id
chat_id = channel_id
typing_target = thread_obj # Type into the new thread
logger.info("[Discord] created thread %s in channel %s for user %s", target_thread_id, channel_id, message.author.display_name)
else:
# Fallback: thread creation failed (disabled/permissions), reply in channel
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
thread_id = channel_id
chat_id = channel_id
typing_target = message.channel # Type into the channel
elif self._thread_mode:
# thread_mode but mention_only is False → create thread anyway for conversation grouping
thread_obj = await self._create_thread(message)
if thread_obj is None:
# Thread creation failed (disabled/permissions), fall back to channel replies
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
thread_id = channel_id
chat_id = channel_id
typing_target = message.channel # Type into the channel
else:
target_thread_id = str(thread_obj.id)
self._active_threads[channel_id] = target_thread_id
self._save_thread(channel_id, target_thread_id)
thread_id = target_thread_id
chat_id = channel_id
typing_target = thread_obj # Type into the new thread
chat_id = str(message.channel.parent_id or message.channel.id)
thread_id = str(message.channel.id)
else:
# No threading — reply directly in channel
thread_id = channel_id
chat_id = channel_id
typing_target = message.channel # Type into the channel
thread = await self._create_thread(message)
if thread is None:
return
chat_id = str(message.channel.id)
thread_id = str(thread.id)
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
inbound = self._make_inbound(
@@ -422,15 +177,6 @@ class DiscordChannel(Channel):
)
inbound.topic_id = thread_id
# Start typing indicator in the correct target (thread or channel)
if typing_target:
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
self._publish(inbound)
asyncio.create_task(self._add_reaction(message))
def _publish(self, inbound) -> None:
"""Publish an inbound message to the main event loop."""
if self._main_loop and self._main_loop.is_running():
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
@@ -452,40 +198,14 @@ class DiscordChannel(Channel):
async def _create_thread(self, message):
try:
if self._discord_module is None:
return None
# Only TextChannel (type 0) and NewsChannel (type 10) support threads
channel_type = message.channel.type
if channel_type not in (
self._discord_module.ChannelType.text,
self._discord_module.ChannelType.news,
):
logger.info(
"[Discord] channel type %s (%s) does not support threads",
channel_type.value,
channel_type.name,
)
return None
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
return await message.create_thread(name=thread_name)
except self._discord_module.errors.HTTPException as exc:
if exc.code == 50024:
logger.info(
"[Discord] cannot create thread in channel %s (error code 50024): %s",
message.channel.id,
channel_type.name if (channel_type := message.channel.type) else "unknown",
)
else:
logger.exception(
"[Discord] failed to create thread for message=%s (HTTPException %s)",
message.id,
exc.code,
)
return None
except Exception:
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
try:
await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.")
except Exception:
pass
return None
async def _resolve_target(self, msg: OutboundMessage):
+7 -16
View File
@@ -787,22 +787,13 @@ class ChannelManager:
return
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
try:
result = await client.runs.wait(
thread_id,
assistant_id,
input={"messages": [{"role": "human", "content": msg.text}]},
config=run_config,
context=run_context,
multitask_strategy="reject",
)
except Exception as exc:
if _is_thread_busy_error(exc):
logger.warning("[Manager] thread busy (concurrent run rejected): thread_id=%s", thread_id)
await self._send_error(msg, THREAD_BUSY_MESSAGE)
return
else:
raise
result = await client.runs.wait(
thread_id,
assistant_id,
input={"messages": [{"role": "human", "content": msg.text}]},
config=run_config,
context=run_context,
)
response_text = _extract_response_text(result)
artifacts = _extract_artifacts(result)
-2
View File
@@ -167,8 +167,6 @@ class ChannelService:
return False
try:
config = dict(config)
config["channel_store"] = self.store
channel = channel_cls(bus=self.bus, config=config)
self._channels[name] = channel
await channel.start()
+3 -31
View File
@@ -8,8 +8,6 @@ from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
_SECRET_FILE = ".jwt_secret"
class AuthConfig(BaseModel):
"""JWT and auth-related configuration. Parsed once at startup.
@@ -32,32 +30,6 @@ class AuthConfig(BaseModel):
_auth_config: AuthConfig | None = None
def _load_or_create_secret() -> str:
"""Load persisted JWT secret from ``{base_dir}/.jwt_secret``, or generate and persist a new one."""
from deerflow.config.paths import get_paths
paths = get_paths()
secret_file = paths.base_dir / _SECRET_FILE
try:
if secret_file.exists():
secret = secret_file.read_text(encoding="utf-8").strip()
if secret:
return secret
except OSError as exc:
raise RuntimeError(f"Failed to read JWT secret from {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can read its persisted auth secret.") from exc
secret = secrets.token_urlsafe(32)
try:
secret_file.parent.mkdir(parents=True, exist_ok=True)
fd = os.open(secret_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
with os.fdopen(fd, "w", encoding="utf-8") as fh:
fh.write(secret)
except OSError as exc:
raise RuntimeError(f"Failed to persist JWT secret to {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can store a stable auth secret.") from exc
return secret
def get_auth_config() -> AuthConfig:
"""Get the global AuthConfig instance. Parses from env on first call."""
global _auth_config
@@ -67,11 +39,11 @@ def get_auth_config() -> AuthConfig:
load_dotenv()
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
if not jwt_secret:
jwt_secret = _load_or_create_secret()
jwt_secret = secrets.token_urlsafe(32)
os.environ["AUTH_JWT_SECRET"] = jwt_secret
logger.warning(
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated secret "
"persisted to .jwt_secret. Sessions will survive restarts. "
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
"Sessions will be invalidated on restart. "
"For production, add AUTH_JWT_SECRET to your .env file: "
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
)
+5 -24
View File
@@ -20,9 +20,6 @@ ACTIVE_CONTENT_MIME_TYPES = {
"image/svg+xml",
}
MAX_SKILL_ARCHIVE_MEMBER_BYTES = 16 * 1024 * 1024
_SKILL_ARCHIVE_READ_CHUNK_SIZE = 64 * 1024
def _build_content_disposition(disposition_type: str, filename: str) -> str:
"""Build an RFC 5987 encoded Content-Disposition header value."""
@@ -47,22 +44,6 @@ def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool:
return False
def _read_skill_archive_member(zip_ref: zipfile.ZipFile, info: zipfile.ZipInfo) -> bytes:
"""Read a .skill archive member while enforcing an uncompressed size cap."""
if info.file_size > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
chunks: list[bytes] = []
total_read = 0
with zip_ref.open(info, "r") as src:
while chunk := src.read(_SKILL_ARCHIVE_READ_CHUNK_SIZE):
total_read += len(chunk)
if total_read > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
chunks.append(chunk)
return b"".join(chunks)
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
"""Extract a file from a .skill ZIP archive.
@@ -79,16 +60,16 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
try:
with zipfile.ZipFile(zip_path, "r") as zip_ref:
# List all files in the archive
infos_by_name = {info.filename: info for info in zip_ref.infolist()}
namelist = zip_ref.namelist()
# Try direct path first
if internal_path in infos_by_name:
return _read_skill_archive_member(zip_ref, infos_by_name[internal_path])
if internal_path in namelist:
return zip_ref.read(internal_path)
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
for name, info in infos_by_name.items():
for name in namelist:
if name.endswith("/" + internal_path) or name == internal_path:
return _read_skill_archive_member(zip_ref, info)
return zip_ref.read(name)
# Not found
return None
+2 -2
View File
@@ -99,7 +99,7 @@ rm -f backend/.deer-flow/data/deerflow.db
| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory |
| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory |
| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) |
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成并持久化到 `.deer-flow/.jwt_secret`,重启后 session 保持 |
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效 |
### 生产环境建议
@@ -137,4 +137,4 @@ python -c "import secrets; print(secrets.token_urlsafe(32))"
| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` |
| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin |
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
| 重启后需要重新登录 | `.jwt_secret` 文件被删除且 `.env` 未设置 `AUTH_JWT_SECRET` | 在 `.env` 中设置固定密钥 |
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |
+401
View File
@@ -0,0 +1,401 @@
# Storage Package Design
## Background
DeerFlow currently has several persistence responsibilities spread across app, gateway, runtime, and legacy persistence modules. This makes the persistence boundary difficult to reason about and creates several migration risks:
- Routers and runtime services can accidentally depend on concrete persistence implementations instead of stable contracts.
- User/auth, run metadata, thread metadata, feedback, run events, and checkpointer setup are initialized through different paths.
- Some persistence behavior is duplicated between memory, SQLite, and PostgreSQL-oriented code paths.
- Incremental migration is hard because app-level code and storage-level code are coupled.
- Adding or validating another SQL backend requires touching app/runtime code instead of a storage-owned package.
The storage package is introduced to make application data persistence a package-level capability with explicit contracts, a clear boundary, and SQL backend compatibility.
## Goals
- Provide a standalone `packages/storage` package for durable application data.
- Support SQLite, PostgreSQL, and MySQL through a shared persistence construction flow.
- Keep LangGraph checkpointer initialization compatible with the same database backend.
- Expose repository contracts as the only package-level data access boundary.
- Let the app layer depend on app-owned adapters under `app.infra.storage`, not on storage DB implementation classes.
- Allow the app/gateway migration to happen in small steps without forcing a large rewrite.
## Non-Goals
- This design does not remove legacy persistence in the first PR.
- This design does not move routers directly onto storage package models.
- This design does not make app routers own SQLAlchemy sessions.
- Cron persistence is intentionally out of scope for the storage package foundation.
- Memory backend is not part of the durable storage package. Memory compatibility, if still needed by app runtime, belongs outside `packages/storage`.
## Storage Design Principles
### Package-Owned Durable Storage
`packages/storage` owns durable application data persistence. It defines:
- configuration shape for storage-backed persistence
- SQLAlchemy models
- repository contracts and DTOs
- SQL repository implementations
- persistence factory functions
- compatibility helpers for config-driven initialization
The package should be usable without importing `app.gateway`, routers, auth providers, or runtime-specific gateway objects.
### SQL Backend Compatibility
The package supports three SQL backends:
- SQLite for local/single-node deployments
- PostgreSQL for production multi-node deployments
- MySQL for deployments that standardize on MySQL
Backend-specific differences are handled inside the storage package:
- SQLAlchemy async engine URL construction
- LangGraph checkpointer connection-string compatibility
- JSON metadata filtering across SQLite/PostgreSQL/MySQL
- SQL dialect behavior around locking, aggregation, and JSON type semantics
### Unified Persistence Bundle
Storage initialization returns an `AppPersistence` bundle:
```python
@dataclass(slots=True)
class AppPersistence:
checkpointer: Checkpointer
engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession]
setup: Callable[[], Awaitable[None]]
aclose: Callable[[], Awaitable[None]]
```
The app runtime can initialize persistence once, call `setup()`, and then inject:
- `checkpointer`
- `session_factory`
- repository adapters
This keeps checkpointer and application data aligned to the same backend without requiring routers to understand database configuration.
## Package Layout
```text
backend/packages/storage/
store/
config/
storage_config.py
app_config.py
persistence/
factory.py
types.py
base_model.py
json_compat.py
drivers/
sqlite.py
postgres.py
mysql.py
repositories/
contracts/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
models/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
db/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
factory.py
```
## Persistence Construction
The primary storage entrypoint is:
```python
from store.persistence import create_persistence_from_storage_config
persistence = await create_persistence_from_storage_config(storage_config)
await persistence.setup()
```
For app-level compatibility with existing database config shape:
```python
from store.persistence import create_persistence_from_database_config
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
```
Expected app startup flow:
```python
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
app.state.persistence = persistence
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
Expected app shutdown flow:
```python
await app.state.persistence.aclose()
```
## Repository Contract Design
Repository contracts are the storage package's public data access boundary. They live under `store.repositories.contracts` and are re-exported from `store.repositories`.
The key contract groups are:
- `UserRepositoryProtocol`
- `RunRepositoryProtocol`
- `ThreadMetaRepositoryProtocol`
- `FeedbackRepositoryProtocol`
- `RunEventRepositoryProtocol`
Each contract owns:
- input DTOs, such as `UserCreate`, `RunCreate`, `ThreadMetaCreate`
- output DTOs, such as `User`, `Run`, `ThreadMeta`
- repository protocol methods
- domain-specific exceptions when needed, such as `InvalidMetadataFilterError`
Repository construction is session-based:
```python
from store.repositories import build_run_repository
async with persistence.session_factory() as session:
repo = build_run_repository(session)
run = await repo.get_run(run_id)
```
This keeps transaction ownership explicit. The storage package does not hide commits or session lifecycle inside global singletons.
## App/Infra Calling Contract
The app layer should not call `store.repositories.db.*` directly. The intended app boundary is `app.infra.storage`.
`app.infra.storage` is responsible for:
- receiving `session_factory` from FastAPI runtime initialization
- owning session lifecycle for app-facing repository methods
- translating storage DTOs to app/gateway DTOs only when needed
- preserving the existing app-facing names during migration
- depending on storage repository protocols, not concrete DB classes
Expected adapter pattern:
```python
class StorageRunRepository(RunRepositoryProtocol):
def __init__(self, session_factory):
self._session_factory = session_factory
async def get_run(self, run_id: str):
async with self._session_factory() as session:
repo = build_run_repository(session)
return await repo.get_run(run_id)
```
For gateway compatibility, app state can keep existing names while the implementation changes:
```python
app.state.run_store = StorageRunStore(run_repository)
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
app.state.run_event_store = StorageRunEventStore(run_event_repository)
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
The app-facing objects may expose legacy method names during migration, but their internal data access should go through storage contracts.
## Boundary Rules
### Allowed Calls
Storage package callers may use:
```python
from store.persistence import create_persistence_from_database_config
from store.persistence import create_persistence_from_storage_config
from store.repositories import build_run_repository
from store.repositories import build_user_repository
from store.repositories import build_thread_meta_repository
from store.repositories import build_feedback_repository
from store.repositories import build_run_event_repository
from store.repositories import RunRepositoryProtocol
from store.repositories import UserRepositoryProtocol
```
App layer callers should use:
```python
from app.infra.storage import StorageRunRepository
from app.infra.storage import StorageUserDataRepository
from app.infra.storage import StorageThreadMetaRepository
from app.infra.storage import StorageFeedbackRepository
from app.infra.storage import StorageRunEventRepository
```
### Prohibited Calls
App/gateway/router/auth code must not import:
```python
from store.repositories.db import DbRunRepository
from store.repositories.models import Run
from store.persistence.base_model import MappedBase
```
Routers must not:
- create SQLAlchemy engines
- create SQLAlchemy sessions directly
- call storage DB repository classes directly
- commit/rollback storage transactions directly unless explicitly scoped by an infra adapter
- depend on storage SQLAlchemy model classes
Storage package code must not import:
```python
import app.gateway
import app.infra
import deerflow.runtime
```
The dependency direction is:
```text
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
```
The reverse direction is forbidden.
## Checkpointer Compatibility
The storage persistence bundle initializes the LangGraph checkpointer alongside application data persistence.
Backend-specific notes:
- SQLite uses `langgraph-checkpoint-sqlite`.
- PostgreSQL uses `langgraph-checkpoint-postgres` and requires a string `postgresql://...` connection URL.
- MySQL uses `langgraph-checkpoint-mysql` and requires a string MySQL connection URL.
SQLAlchemy may use async driver URLs such as `postgresql+asyncpg://...` or `mysql+aiomysql://...`, but LangGraph checkpointer constructors expect plain string connection URLs. This conversion belongs inside the storage driver implementation.
## JSON Metadata Filtering
Thread metadata search supports dialect-aware JSON filtering through `store.persistence.json_compat`.
The matcher supports:
- `None`
- `bool`
- `int`
- `float`
- `str`
It rejects:
- unsafe keys
- nested JSON path expressions
- dict/list values
- integers outside signed 64-bit range
This prevents SQL/JSON path injection, avoids compiled-cache type drift, and preserves type semantics such as `True != 1` and explicit JSON `null` not matching a missing key.
## Step-by-Step Implementation Plan
### Step 1: Introduce Storage Package Foundation
- Add `backend/packages/storage`.
- Add storage config models.
- Add `AppPersistence`.
- Add SQLite/PostgreSQL/MySQL persistence drivers.
- Add repository contracts, models, DB implementations, and factory helpers.
- Add package dependency wiring.
- Exclude cron persistence.
### Step 2: Harden Storage Backend Compatibility
- Validate SQLite setup and repository behavior.
- Validate PostgreSQL and MySQL with local E2E tests.
- Fix checkpointer connection-string compatibility.
- Fix PostgreSQL locking and aggregation differences.
- Add dialect-aware JSON metadata filtering.
### Step 3: Add App Infra Adapters
- Add `backend/app/infra/storage`.
- Implement app-facing repositories that own session lifecycle.
- Keep storage contracts as the only data access boundary.
- Add legacy compatibility adapters for existing app/gateway method shapes.
- Keep app/gateway imports out of `packages/storage`.
### Step 4: Switch FastAPI Runtime Injection
- Initialize storage persistence in FastAPI startup/lifespan.
- Attach `persistence`, `checkpointer`, and `session_factory` to `app.state`.
- Preserve existing external state names:
- `run_store`
- `feedback_repo`
- `thread_store`
- `run_event_store`
- `checkpointer`
- `session_factory`
- Start with user/auth provider construction, then migrate run/thread/feedback/run_event.
### Step 5: Router and Auth Compatibility
- Ensure routers consume app-facing adapters, not storage DB classes.
- Ensure auth providers depend on user repository contracts.
- Keep router response shapes unchanged.
- Add focused auth/admin/router regression tests.
### Step 6: Cleanup Legacy Persistence
- Compare old persistence usage after app/gateway migration.
- Remove unused old repository implementations only after all call sites move.
- Keep compatibility shims only where needed for a transition window.
- Delete memory backend paths from storage-owned durable persistence.
## Testing Strategy
Unit tests should cover:
- config parsing
- persistence setup
- table creation
- repository CRUD/query behavior
- typed JSON metadata filtering
- dialect SQL compilation
- cron exclusion
E2E tests should cover:
- SQLite persistence setup
- PostgreSQL temporary database setup
- MySQL temporary database setup
- repository contract behavior across all supported SQL backends
- JSON/Unicode round trip
- rollback behavior
- persistence close/cleanup
E2E tests may remain local-only if CI does not provide PostgreSQL/MySQL services.
+401
View File
@@ -0,0 +1,401 @@
# Storage Package 设计文档
## 背景
DeerFlow 当前有多类持久化职责分散在 app、gateway、runtime 和旧 persistence 模块中。这会带来几个问题:
- routers 和 runtime services 容易依赖具体 persistence 实现,而不是稳定契约。
- user/auth、run metadata、thread metadata、feedback、run events、checkpointer setup 的初始化路径不统一。
- memory、SQLite、PostgreSQL 相关路径中存在部分重复逻辑。
- app 层代码和 storage 层代码耦合,导致增量迁移困难。
- 增加或验证新的 SQL backend 时,需要改动 app/runtime,而不是只改 storage package。
引入 storage package 的目标,是把应用数据持久化抽象成 package 级能力,并提供明确契约、清晰边界和 SQL backend 兼容性。
## 目标
- 新增独立的 `packages/storage`,负责 durable application data。
- 通过统一 persistence 构造流程支持 SQLite、PostgreSQL、MySQL。
- 保持 LangGraph checkpointer 与同一个数据库 backend 兼容。
- 将 repository contracts 作为 package 对外唯一数据访问边界。
- app 层通过 `app.infra.storage` 适配 storage,而不是直接依赖 storage DB 实现类。
- 支持 app/gateway 后续小步迁移,避免一次性大重构。
## 非目标
- 第一阶段不删除旧 persistence。
- 不让 routers 直接依赖 storage package models。
- 不让 app routers 管理 SQLAlchemy sessions。
- cron persistence 不属于 storage package 基础迁移范围。
- memory backend 不属于 durable storage package。若 app runtime 仍需要 memory 兼容,应放在 `packages/storage` 之外。
## Storage 设计理念
### Package 自己负责 Durable Storage
`packages/storage` 负责应用数据的 durable persistence,包括:
- storage 持久化配置
- SQLAlchemy models
- repository contracts 和 DTOs
- SQL repository 实现
- persistence factory functions
- 面向现有 config 的兼容初始化入口
该 package 不应该 import `app.gateway`、routers、auth providers 或 runtime 中的 gateway 对象。
### SQL Backend 兼容
该 package 支持三种 SQL backend
- SQLite:本地或单节点部署
- PostgreSQL:生产多节点部署
- MySQL:使用 MySQL 作为标准数据库的部署
backend 差异在 storage package 内部处理:
- SQLAlchemy async engine URL 构造
- LangGraph checkpointer 连接串兼容
- SQLite/PostgreSQL/MySQL 的 JSON metadata filter
- 不同 SQL 方言在 locking、aggregation、JSON 类型语义上的差异
### 统一 Persistence Bundle
Storage 初始化返回 `AppPersistence` bundle
```python
@dataclass(slots=True)
class AppPersistence:
checkpointer: Checkpointer
engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession]
setup: Callable[[], Awaitable[None]]
aclose: Callable[[], Awaitable[None]]
```
app runtime 只需要初始化一次 persistence,调用 `setup()`,然后注入:
- `checkpointer`
- `session_factory`
- repository adapters
这样 checkpointer 和应用数据可以对齐到同一个 backend,同时 routers 不需要理解数据库配置。
## Package 结构
```text
backend/packages/storage/
store/
config/
storage_config.py
app_config.py
persistence/
factory.py
types.py
base_model.py
json_compat.py
drivers/
sqlite.py
postgres.py
mysql.py
repositories/
contracts/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
models/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
db/
user.py
run.py
thread_meta.py
feedback.py
run_event.py
factory.py
```
## Persistence 构造
storage 的主要入口:
```python
from store.persistence import create_persistence_from_storage_config
persistence = await create_persistence_from_storage_config(storage_config)
await persistence.setup()
```
为了兼容现有 app database config,也提供:
```python
from store.persistence import create_persistence_from_database_config
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
```
预期 app startup 流程:
```python
persistence = await create_persistence_from_database_config(config.database)
await persistence.setup()
app.state.persistence = persistence
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
预期 app shutdown 流程:
```python
await app.state.persistence.aclose()
```
## Repository 契约设计
Repository contracts 是 storage package 对外公开的数据访问边界。它们位于 `store.repositories.contracts`,并通过 `store.repositories` re-export。
主要契约包括:
- `UserRepositoryProtocol`
- `RunRepositoryProtocol`
- `ThreadMetaRepositoryProtocol`
- `FeedbackRepositoryProtocol`
- `RunEventRepositoryProtocol`
每组契约包含:
- 输入 DTO,例如 `UserCreate``RunCreate``ThreadMetaCreate`
- 输出 DTO,例如 `User``Run``ThreadMeta`
- repository protocol methods
- 必要的领域异常,例如 `InvalidMetadataFilterError`
Repository 通过 session 构造:
```python
from store.repositories import build_run_repository
async with persistence.session_factory() as session:
repo = build_run_repository(session)
run = await repo.get_run(run_id)
```
这样可以让 transaction ownership 保持明确。storage package 不通过全局 singleton 隐式隐藏 commit 或 session 生命周期。
## App/Infra 调用契约
app 层不应该直接调用 `store.repositories.db.*`。预期的 app 边界是 `app.infra.storage`
`app.infra.storage` 负责:
- 从 FastAPI runtime 初始化中接收 `session_factory`
- 为 app-facing repository methods 管理 session 生命周期
- 在必要时将 storage DTOs 转成 app/gateway DTOs
- 迁移期间保留现有 app-facing 名称
- 依赖 storage repository protocols,而不是具体 DB classes
预期 adapter 模式:
```python
class StorageRunRepository(RunRepositoryProtocol):
def __init__(self, session_factory):
self._session_factory = session_factory
async def get_run(self, run_id: str):
async with self._session_factory() as session:
repo = build_run_repository(session)
return await repo.get_run(run_id)
```
为了兼容 gatewayapp state 可以暂时保持现有名字,只替换内部实现:
```python
app.state.run_store = StorageRunStore(run_repository)
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
app.state.run_event_store = StorageRunEventStore(run_event_repository)
app.state.checkpointer = persistence.checkpointer
app.state.session_factory = persistence.session_factory
```
app-facing objects 可以在迁移期间保留旧方法名,但内部数据访问必须经过 storage contracts。
## 边界规则
### 允许调用的范围
storage package 调用方可以使用:
```python
from store.persistence import create_persistence_from_database_config
from store.persistence import create_persistence_from_storage_config
from store.repositories import build_run_repository
from store.repositories import build_user_repository
from store.repositories import build_thread_meta_repository
from store.repositories import build_feedback_repository
from store.repositories import build_run_event_repository
from store.repositories import RunRepositoryProtocol
from store.repositories import UserRepositoryProtocol
```
app 层应该使用:
```python
from app.infra.storage import StorageRunRepository
from app.infra.storage import StorageUserDataRepository
from app.infra.storage import StorageThreadMetaRepository
from app.infra.storage import StorageFeedbackRepository
from app.infra.storage import StorageRunEventRepository
```
### 禁止调用的范围
app/gateway/router/auth 代码不应该 import
```python
from store.repositories.db import DbRunRepository
from store.repositories.models import Run
from store.persistence.base_model import MappedBase
```
routers 禁止:
- 创建 SQLAlchemy engines
- 直接创建 SQLAlchemy sessions
- 直接调用 storage DB repository classes
- 直接 commit/rollback storage transactions,除非这是 infra adapter 明确管理的范围
- 依赖 storage SQLAlchemy model classes
storage package 禁止 import
```python
import app.gateway
import app.infra
import deerflow.runtime
```
依赖方向必须是:
```text
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
```
禁止反向依赖。
## Checkpointer 兼容
storage persistence bundle 会同时初始化 LangGraph checkpointer 和应用数据持久化。
backend 说明:
- SQLite 使用 `langgraph-checkpoint-sqlite`
- PostgreSQL 使用 `langgraph-checkpoint-postgres`,需要字符串形式的 `postgresql://...` 连接串。
- MySQL 使用 `langgraph-checkpoint-mysql`,需要字符串形式的 MySQL 连接串。
SQLAlchemy 可以使用 `postgresql+asyncpg://...``mysql+aiomysql://...` 这类 async driver URL,但 LangGraph checkpointer 构造函数需要普通字符串连接串。这个转换应该封装在 storage driver implementation 内部。
## JSON Metadata Filtering
Thread metadata search 通过 `store.persistence.json_compat` 支持跨方言 JSON filtering。
支持的 filter value 类型:
- `None`
- `bool`
- `int`
- `float`
- `str`
拒绝:
- unsafe keys
- nested JSON path expressions
- dict/list values
- 超出 signed 64-bit 范围的整数
这样可以避免 SQL/JSON path injection,避免 compiled-cache 类型漂移,并保留类型语义,例如 `True != 1`,显式 JSON `null` 不等于 missing key。
## 分步实现方案
### 第 1 步:新增 Storage Package 基础
- 新增 `backend/packages/storage`
- 增加 storage config models。
- 增加 `AppPersistence`
- 增加 SQLite/PostgreSQL/MySQL persistence drivers。
- 增加 repository contracts、models、DB implementations 和 factory helpers。
- 接入 package dependency。
- 排除 cron persistence。
### 第 2 步:补齐 Storage Backend 兼容性
- 验证 SQLite setup 和 repository 行为。
- 使用本地 E2E 验证 PostgreSQL 和 MySQL。
- 修复 checkpointer 连接串兼容。
- 修复 PostgreSQL locking 和 aggregation 差异。
- 增加跨方言 JSON metadata filtering。
### 第 3 步:新增 App Infra Adapters
- 新增 `backend/app/infra/storage`
- 实现 app-facing repositories,由它们管理 session 生命周期。
- 保持 storage contracts 作为唯一数据访问边界。
- 为现有 app/gateway method shape 增加兼容 adapters。
- 避免 `packages/storage` import app/gateway。
### 第 4 步:切换 FastAPI Runtime 注入
- 在 FastAPI startup/lifespan 中初始化 storage persistence。
-`persistence``checkpointer``session_factory` 注入 `app.state`
- 暂时保留现有对外 state 名称:
- `run_store`
- `feedback_repo`
- `thread_store`
- `run_event_store`
- `checkpointer`
- `session_factory`
- 先切 user/auth provider 构造,再逐步迁移 run/thread/feedback/run_event。
### 第 5 步:Router 和 Auth 兼容
- 确保 routers 消费 app-facing adapters,而不是 storage DB classes。
- 确保 auth providers 依赖 user repository contracts。
- 保持 router response shapes 不变。
- 增加 auth/admin/router regression tests。
### 第 6 步:清理旧 Persistence
- app/gateway 迁移完成后,再比较旧 persistence usage。
- 所有 call sites 迁移完成后,再删除未使用的旧 repository implementations。
- 只在必要时保留短期 compatibility shims。
- 从 storage-owned durable persistence 中移除 memory backend 路径。
## 测试策略
单测应覆盖:
- config parsing
- persistence setup
- table creation
- repository CRUD/query behavior
- typed JSON metadata filtering
- dialect SQL compilation
- cron exclusion
E2E 应覆盖:
- SQLite persistence setup
- PostgreSQL temporary database setup
- MySQL temporary database setup
- 所有支持 SQL backend 下的 repository contract 行为
- JSON/Unicode round trip
- rollback behavior
- persistence close/cleanup
如果 CI 暂时没有 PostgreSQL/MySQL servicesE2E 可以先作为 local-only 验证保留。
@@ -104,46 +104,45 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
return "[Tool call was interrupted and did not return a result.]"
def _build_patched_messages(self, messages: list) -> list | None:
"""Return messages with tool results grouped after their tool-call AIMessage.
"""Return a new message list with patches inserted at the correct positions.
This normalizes model-bound causal order before provider serialization while
preserving already-valid transcripts unchanged.
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
a synthetic ToolMessage is inserted immediately after that AIMessage.
Returns None if no patches are needed.
"""
tool_messages_by_id: dict[str, ToolMessage] = {}
# Collect IDs of all existing ToolMessages
existing_tool_msg_ids: set[str] = set()
for msg in messages:
if isinstance(msg, ToolMessage):
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
existing_tool_msg_ids.add(msg.tool_call_id)
tool_call_ids: set[str] = set()
# Check if any patching is needed
needs_patch = False
for msg in messages:
if getattr(msg, "type", None) != "ai":
continue
for tc in self._message_tool_calls(msg):
tc_id = tc.get("id")
if tc_id:
tool_call_ids.add(tc_id)
if tc_id and tc_id not in existing_tool_msg_ids:
needs_patch = True
break
if needs_patch:
break
if not needs_patch:
return None
# Build new list with patches inserted right after each dangling AIMessage
patched: list = []
consumed_tool_msg_ids: set[str] = set()
patched_ids: set[str] = set()
patch_count = 0
for msg in messages:
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
continue
patched.append(msg)
if getattr(msg, "type", None) != "ai":
continue
for tc in self._message_tool_calls(msg):
tc_id = tc.get("id")
if not tc_id or tc_id in consumed_tool_msg_ids:
continue
existing_tool_msg = tool_messages_by_id.get(tc_id)
if existing_tool_msg is not None:
patched.append(existing_tool_msg)
consumed_tool_msg_ids.add(tc_id)
else:
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
patched.append(
ToolMessage(
content=self._synthetic_tool_message_content(tc),
@@ -152,14 +151,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
status="error",
)
)
consumed_tool_msg_ids.add(tc_id)
patched_ids.add(tc_id)
patch_count += 1
if patched == messages:
return None
if patch_count:
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
return patched
@override
@@ -10,7 +10,6 @@ from typing import Any, Protocol, override, runtime_checkable
from langchain.agents import AgentState
from langchain.agents.middleware import SummarizationMiddleware
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
from langchain_core.messages.utils import get_buffer_string
from langgraph.config import get_config
from langgraph.graph.message import REMOVE_ALL_MESSAGES
from langgraph.runtime import Runtime
@@ -176,84 +175,12 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
]
}
@override
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary without emitting streaming events to the client.
Suppresses callbacks to prevent the internal summarization LLM call from
producing visible AI message chunks in the frontend's ``messages-tuple``
stream (issue #2804).
"""
if not messages_to_summarize:
return "No previous conversation history."
trimmed = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed:
return "Previous conversation was too long to summarize."
formatted = get_buffer_string(trimmed)
try:
response = self.model.with_config(callbacks=[]).invoke(
self.summary_prompt.format(messages=formatted).rstrip(),
config={
"metadata": {"lc_source": "summarization"},
"callbacks": [],
},
)
return self._extract_summary_text(response)
except Exception as e:
return f"Error generating summary: {e!s}"
@override
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
"""Generate summary without emitting streaming events to the client.
Suppresses callbacks to prevent the internal summarization LLM call from
producing visible AI message chunks in the frontend's ``messages-tuple``
stream (issue #2804).
"""
if not messages_to_summarize:
return "No previous conversation history."
trimmed = self._trim_messages_for_summary(messages_to_summarize)
if not trimmed:
return "Previous conversation was too long to summarize."
formatted = get_buffer_string(trimmed)
try:
response = await self.model.with_config(callbacks=[]).ainvoke(
self.summary_prompt.format(messages=formatted).rstrip(),
config={
"metadata": {"lc_source": "summarization"},
"callbacks": [],
},
)
return self._extract_summary_text(response)
except Exception as e:
return f"Error generating summary: {e!s}"
def _extract_summary_text(self, response: Any) -> str:
# Prefer .text which normalizes list content blocks (e.g. [{"type": "text", "text": "..."}]).
# Fall back to .content for non-LangChain responses.
summary_text = getattr(response, "text", None)
if summary_text is None:
summary_text = getattr(response, "content", "")
return summary_text.strip() if isinstance(summary_text, str) else str(summary_text).strip()
@override
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
"""Override the base implementation to let the human message with the special name 'summary'.
And this message will be ignored to display in the frontend, but still can be used as context for the model.
"""
return [
HumanMessage(
content=f"Here is a summary of the conversation to date:\n\n{summary}",
name="summary",
additional_kwargs={"hide_from_ui": True},
)
]
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
def _preserve_dynamic_context_reminders(
self,
@@ -7,21 +7,17 @@ reminder message so the model still knows about the outstanding todo list.
Additionally, this middleware prevents the agent from exiting the loop while
there are still incomplete todo items. When the model produces a final response
(no tool calls) but todos are not yet complete, the middleware queues a reminder
for the next model request and jumps back to the model node to force continued
engagement. The completion reminder is injected via ``wrap_model_call`` instead
of being persisted into graph state as a normal user-visible message.
(no tool calls) but todos are not yet complete, the middleware injects a reminder
and jumps back to the model node to force continued engagement.
"""
from __future__ import annotations
import threading
from collections.abc import Awaitable, Callable
from typing import Any, override
from langchain.agents.middleware import TodoListMiddleware
from langchain.agents.middleware.todo import PlanningState, Todo
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
from langchain.agents.middleware.types import hook_config
from langchain_core.messages import AIMessage, HumanMessage
from langgraph.runtime import Runtime
@@ -59,51 +55,6 @@ def _format_todos(todos: list[Todo]) -> str:
return "\n".join(lines)
def _format_completion_reminder(todos: list[Todo]) -> str:
"""Format a completion reminder for incomplete todo items."""
incomplete = [t for t in todos if t.get("status") != "completed"]
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
return (
"<system_reminder>\n"
"You have incomplete todo items that must be finished before giving your final response:\n\n"
f"{incomplete_text}\n\n"
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
"as you finish them, and only respond when all items are done.\n"
"</system_reminder>"
)
_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"}
def _has_tool_call_intent_or_error(message: AIMessage) -> bool:
"""Return True when an AIMessage is not a clean final answer.
Todo completion reminders should only fire when the model has produced a
plain final response. Provider/tool parsing details have moved across
LangChain versions and integrations, so keep all tool-intent/error signals
behind this helper instead of checking one concrete field at the call site.
"""
if message.tool_calls:
return True
if getattr(message, "invalid_tool_calls", None):
return True
# Backward/provider compatibility: some integrations preserve raw or legacy
# tool-call intent in additional_kwargs even when structured tool_calls is
# empty. If this helper changes, update the matching sentinel test
# `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`;
# if that test fails after a LangChain upgrade, review this helper so new
# tool-call/error fields are not silently treated as clean final answers.
additional_kwargs = getattr(message, "additional_kwargs", {}) or {}
if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"):
return True
response_metadata = getattr(message, "response_metadata", {}) or {}
return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS
class TodoMiddleware(TodoListMiddleware):
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
@@ -138,7 +89,6 @@ class TodoMiddleware(TodoListMiddleware):
formatted = _format_todos(todos)
reminder = HumanMessage(
name="todo_reminder",
additional_kwargs={"hide_from_ui": True},
content=(
"<system_reminder>\n"
"Your todo list from earlier is no longer visible in the current context window, "
@@ -163,100 +113,6 @@ class TodoMiddleware(TodoListMiddleware):
# Maximum number of completion reminders before allowing the agent to exit.
# This prevents infinite loops when the agent cannot make further progress.
_MAX_COMPLETION_REMINDERS = 2
# Hard cap for per-run reminder bookkeeping in long-lived middleware instances.
_MAX_COMPLETION_REMINDER_KEYS = 4096
def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self._lock = threading.Lock()
self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {}
self._completion_reminder_counts: dict[tuple[str, str], int] = {}
self._completion_reminder_touch_order: dict[tuple[str, str], int] = {}
self._completion_reminder_next_order = 0
@staticmethod
def _get_thread_id(runtime: Runtime) -> str:
context = getattr(runtime, "context", None)
thread_id = context.get("thread_id") if context else None
return str(thread_id) if thread_id else "default"
@staticmethod
def _get_run_id(runtime: Runtime) -> str:
context = getattr(runtime, "context", None)
run_id = context.get("run_id") if context else None
return str(run_id) if run_id else "default"
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
return self._get_thread_id(runtime), self._get_run_id(runtime)
def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
self._completion_reminder_next_order += 1
self._completion_reminder_touch_order[key] = self._completion_reminder_next_order
def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]:
keys = set(self._pending_completion_reminders)
keys.update(self._completion_reminder_counts)
keys.update(self._completion_reminder_touch_order)
return keys
def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
self._pending_completion_reminders.pop(key, None)
self._completion_reminder_counts.pop(key, None)
self._completion_reminder_touch_order.pop(key, None)
def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None:
keys = self._completion_reminder_keys_locked()
overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS
if overflow <= 0:
return
candidates = [key for key in keys if key != protected_key]
candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0))
for key in candidates[:overflow]:
self._drop_completion_reminder_key_locked(key)
def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None:
key = self._pending_key(runtime)
with self._lock:
self._pending_completion_reminders.setdefault(key, []).append(reminder)
self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1
self._touch_completion_reminder_key_locked(key)
self._prune_completion_reminder_state_locked(protected_key=key)
def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int:
key = self._pending_key(runtime)
with self._lock:
return self._completion_reminder_counts.get(key, 0)
def _drain_completion_reminders(self, runtime: Runtime) -> list[str]:
key = self._pending_key(runtime)
with self._lock:
reminders = self._pending_completion_reminders.pop(key, [])
if reminders or key in self._completion_reminder_counts:
self._touch_completion_reminder_key_locked(key)
return reminders
def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None:
thread_id, current_run_id = self._pending_key(runtime)
with self._lock:
for key in self._completion_reminder_keys_locked():
if key[0] == thread_id and key[1] != current_run_id:
self._drop_completion_reminder_key_locked(key)
def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None:
key = self._pending_key(runtime)
with self._lock:
self._drop_completion_reminder_key_locked(key)
@override
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime)
return None
@override
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_other_run_completion_reminders(runtime)
return None
@hook_config(can_jump_to=["model"])
@override
@@ -281,12 +137,10 @@ class TodoMiddleware(TodoListMiddleware):
if base_result is not None:
return base_result
# 2. Only intervene when the agent wants to exit cleanly. Tool-call
# intent or tool-call parse errors should be handled by the tool path
# instead of being masked by todo reminders.
# 2. Only intervene when the agent wants to exit (no tool calls).
messages = state.get("messages") or []
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
if not last_ai or _has_tool_call_intent_or_error(last_ai):
if not last_ai or last_ai.tool_calls:
return None
# 3. Allow exit when all todos are completed or there are no todos.
@@ -295,14 +149,24 @@ class TodoMiddleware(TodoListMiddleware):
return None
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS:
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
return None
# 5. Queue a reminder for the next model request and jump back. We must
# not persist this control prompt as a normal HumanMessage, otherwise it
# can leak into user-visible message streams and saved transcripts.
self._queue_completion_reminder(runtime, _format_completion_reminder(todos))
return {"jump_to": "model"}
# 5. Inject a reminder and force the agent back to the model.
incomplete = [t for t in todos if t.get("status") != "completed"]
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
reminder = HumanMessage(
name="todo_completion_reminder",
content=(
"<system_reminder>\n"
"You have incomplete todo items that must be finished before giving your final response:\n\n"
f"{incomplete_text}\n\n"
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
"as you finish them, and only respond when all items are done.\n"
"</system_reminder>"
),
)
return {"jump_to": "model", "messages": [reminder]}
@override
@hook_config(can_jump_to=["model"])
@@ -313,47 +177,3 @@ class TodoMiddleware(TodoListMiddleware):
) -> dict[str, Any] | None:
"""Async version of after_model."""
return self.after_model(state, runtime)
@staticmethod
def _format_pending_completion_reminders(reminders: list[str]) -> str:
return "\n\n".join(dict.fromkeys(reminders))
def _augment_request(self, request: ModelRequest) -> ModelRequest:
reminders = self._drain_completion_reminders(request.runtime)
if not reminders:
return request
new_messages = [
*request.messages,
HumanMessage(
content=self._format_pending_completion_reminders(reminders),
name="todo_completion_reminder",
additional_kwargs={"hide_from_ui": True},
),
]
return request.override(messages=new_messages)
@override
def wrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], ModelResponse],
) -> ModelCallResult:
return handler(self._augment_request(request))
@override
async def awrap_model_call(
self,
request: ModelRequest,
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
) -> ModelCallResult:
return await handler(self._augment_request(request))
@override
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime)
return None
@override
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
self._clear_current_run_completion_reminders(runtime)
return None
@@ -35,7 +35,7 @@ def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
if app_config is None:
try:
app_config = get_app_config()
except FileNotFoundError:
except (FileNotFoundError, ValueError):
return False
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
+35
View File
@@ -0,0 +1,35 @@
[project]
name = "deerflow-storage"
version = "0.1.0"
description = "DeerFlow storage framework"
requires-python = ">=3.12"
dependencies = [
"dotenv>=0.9.9",
"pydantic>=2.12.5",
"pyyaml>=6.0.3",
"sqlalchemy[asyncio]>=2.0,<3.0",
"alembic>=1.13",
"langgraph>=1.1.9",
]
[project.optional-dependencies]
postgres = [
"asyncpg>=0.29",
"langgraph-checkpoint-postgres>=3.0.5",
"psycopg[binary]>=3.3.3",
"psycopg-pool>=3.3.0",
]
mysql = [
"aiomysql>=0.2",
"langgraph-checkpoint-mysql>=3.0.0",
]
sqlite = [
"aiosqlite>=0.22.1",
"langgraph-checkpoint-sqlite>=3.0.3"
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.hatch.build.targets.wheel]
packages = ["store"]
@@ -0,0 +1,5 @@
from .enums import DataBaseType
__all__ = [
"DataBaseType",
]
@@ -0,0 +1,41 @@
from enum import Enum
from enum import IntEnum as SourceIntEnum
from enum import StrEnum as SourceStrEnum
from typing import Any, TypeVar
T = TypeVar("T", bound=Enum)
class _EnumBase:
"""Base enum class with common utility methods."""
@classmethod
def get_member_keys(cls) -> list[str]:
"""Return a list of enum member names."""
return list(cls.__members__.keys())
@classmethod
def get_member_values(cls) -> list:
"""Return a list of enum member values."""
return [item.value for item in cls.__members__.values()]
@classmethod
def get_member_dict(cls) -> dict[str, Any]:
"""Return a dict mapping member names to values."""
return {name: item.value for name, item in cls.__members__.items()}
class IntEnum(_EnumBase, SourceIntEnum):
"""Integer enum base class."""
class StrEnum(_EnumBase, SourceStrEnum):
"""String enum base class."""
class DataBaseType(StrEnum):
"""Database type."""
sqlite = "sqlite"
mysql = "mysql"
postgresql = "postgresql"
@@ -0,0 +1,286 @@
import logging
import os
from contextvars import ContextVar
from pathlib import Path
from typing import Any, Self
import yaml
from dotenv import load_dotenv
from pydantic import BaseModel, ConfigDict, Field
from store.config.storage_config import StorageConfig
load_dotenv()
logger = logging.getLogger(__name__)
def _default_config_candidates() -> tuple[Path, ...]:
"""Return deterministic config.yaml locations without relying on cwd."""
backend_dir = Path(__file__).resolve().parents[4]
repo_root = backend_dir.parent
cwd = Path.cwd().resolve()
candidates = (
cwd / "config.yaml",
backend_dir / "config.yaml",
repo_root / "config.yaml",
)
return tuple(dict.fromkeys(candidates))
def _storage_from_database_config(config_data: dict[str, Any]) -> None:
"""Keep the existing public `database:` config compatible with storage."""
if "storage" in config_data:
return
database = config_data.get("database")
if not isinstance(database, dict):
return
backend = database.get("backend")
if backend == "memory":
raise ValueError("database.backend='memory' is not supported by storage; handle memory mode before loading storage config")
storage: dict[str, Any] = {
"driver": "postgres" if backend == "postgres" else backend,
"sqlite_dir": database.get("sqlite_dir", ".deer-flow/data"),
"echo_sql": database.get("echo_sql", False),
"pool_size": database.get("pool_size", 5),
}
postgres_url = database.get("postgres_url")
if backend == "postgres" and isinstance(postgres_url, str) and postgres_url:
from sqlalchemy.engine.url import make_url
parsed = make_url(postgres_url)
storage["database_url"] = postgres_url
storage.update(
{
"username": parsed.username or "",
"password": parsed.password or "",
"host": parsed.host or "localhost",
"port": parsed.port or 5432,
"db_name": parsed.database or "deerflow",
}
)
config_data["storage"] = storage
class AppConfig(BaseModel):
"""DeerFlow application configuration."""
timezone: str = Field(default="UTC", description="Timezone for scheduling and timestamps (e.g. 'UTC', 'America/New_York')")
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
storage: StorageConfig = Field(default=StorageConfig())
model_config = ConfigDict(extra="allow", frozen=False)
@classmethod
def resolve_config_path(cls, config_path: str | None = None) -> Path:
"""Resolve the config file path.
Priority:
1. If provided `config_path` argument, use it.
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
"""
if config_path:
path = Path(config_path)
if not Path.exists(path):
raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
return path
elif os.getenv("DEER_FLOW_CONFIG_PATH"):
path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
if not Path.exists(path):
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
return path
else:
for path in _default_config_candidates():
if path.exists():
return path
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
@classmethod
def from_file(cls, config_path: str | None = None) -> Self:
"""Load and validate config from YAML. See `resolve_config_path` for path resolution."""
resolved_path = cls.resolve_config_path(config_path)
with open(resolved_path, encoding="utf-8") as f:
config_data = yaml.safe_load(f) or {}
cls._check_config_version(config_data, resolved_path)
config_data = cls.resolve_env_variables(config_data)
_storage_from_database_config(config_data)
if os.getenv("TIMEZONE"):
config_data["timezone"] = os.getenv("TIMEZONE")
result = cls.model_validate(config_data)
return result
@classmethod
def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
"""Check if the user's config.yaml is outdated compared to config.example.yaml.
Emits a warning if the user's config_version is lower than the example's.
Missing config_version is treated as version 0 (pre-versioning).
"""
try:
user_version = int(config_data.get("config_version", 0))
except (TypeError, ValueError):
user_version = 0
# Find config.example.yaml by searching config.yaml's directory and its parents
example_path = None
search_dir = config_path.parent
for _ in range(5): # search up to 5 levels
candidate = search_dir / "config.example.yaml"
if candidate.exists():
example_path = candidate
break
parent = search_dir.parent
if parent == search_dir:
break
search_dir = parent
if example_path is None:
return
try:
with open(example_path, encoding="utf-8") as f:
example_data = yaml.safe_load(f)
raw = example_data.get("config_version", 0) if example_data else 0
try:
example_version = int(raw)
except (TypeError, ValueError):
example_version = 0
except Exception:
return
if user_version < example_version:
logger.warning(
"Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to merge new fields into your config.",
user_version,
example_version,
)
@classmethod
def resolve_env_variables(cls, config: Any) -> Any:
"""Recursively replace $VAR strings with their environment variable values (e.g. $OPENAI_API_KEY)."""
if isinstance(config, str):
if config.startswith("$"):
env_value = os.getenv(config[1:])
if env_value is None:
raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
return env_value
return config
elif isinstance(config, dict):
return {k: cls.resolve_env_variables(v) for k, v in config.items()}
elif isinstance(config, list):
return [cls.resolve_env_variables(item) for item in config]
return config
_app_config: AppConfig | None = None
_app_config_path: Path | None = None
_app_config_mtime: float | None = None
_app_config_is_custom = False
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
def _get_config_mtime(config_path: Path) -> float | None:
"""Get the modification time of a config file if it exists."""
try:
return config_path.stat().st_mtime
except OSError:
return None
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
"""Load config from disk and refresh cache metadata."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
resolved_path = AppConfig.resolve_config_path(config_path)
_app_config = AppConfig.from_file(str(resolved_path))
_app_config_path = resolved_path
_app_config_mtime = _get_config_mtime(resolved_path)
_app_config_is_custom = False
return _app_config
def get_app_config() -> AppConfig:
"""Get the DeerFlow config instance.
Returns a cached singleton instance and automatically reloads it when the
underlying config file path or modification time changes. Use
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
the cache.
"""
global _app_config, _app_config_path, _app_config_mtime
runtime_override = _current_app_config.get()
if runtime_override is not None:
return runtime_override
if _app_config is not None and _app_config_is_custom:
return _app_config
resolved_path = AppConfig.resolve_config_path()
current_mtime = _get_config_mtime(resolved_path)
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
if should_reload:
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
logger.info(
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
_app_config_mtime,
current_mtime,
)
_load_and_cache_app_config(str(resolved_path))
return _app_config
def reload_app_config(config_path: str | None = None) -> AppConfig:
"""Force reload from file and update the cache."""
return _load_and_cache_app_config(config_path)
def reset_app_config() -> None:
"""Clear the cache so the next `get_app_config()` reloads from file."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = None
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = False
def set_app_config(config: AppConfig) -> None:
"""Inject a config instance directly, bypassing file loading (for testing)."""
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
_app_config = config
_app_config_path = None
_app_config_mtime = None
_app_config_is_custom = True
def peek_current_app_config() -> AppConfig | None:
"""Return the runtime-scoped AppConfig override, if one is active."""
return _current_app_config.get()
def push_current_app_config(config: AppConfig) -> None:
"""Push a runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
_current_app_config_stack.set(stack + (_current_app_config.get(),))
_current_app_config.set(config)
def pop_current_app_config() -> None:
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
stack = _current_app_config_stack.get()
if not stack:
_current_app_config.set(None)
return
previous = stack[-1]
_current_app_config_stack.set(stack[:-1])
_current_app_config.set(previous)
@@ -0,0 +1,69 @@
"""Unified storage backend configuration for checkpointer and application data.
SQLite: checkpointer → {sqlite_dir}/checkpoints.db, app → {sqlite_dir}/deerflow.db
(separate files to avoid write-lock contention)
Postgres: shared URL, independent connection pools per layer.
Sensitive values use $VAR syntax resolved by AppConfig.resolve_env_variables()
before this config is instantiated.
"""
from __future__ import annotations
import os
from typing import Literal
from pydantic import BaseModel, Field
def _strip_legacy_state_prefix(path: str) -> str:
"""Keep old .deer-flow/* config values compatible with Paths.base_dir."""
prefix = ".deer-flow/"
if path == ".deer-flow":
return "."
if path.startswith(prefix):
return path[len(prefix) :]
return path
class StorageConfig(BaseModel):
driver: Literal["mysql", "sqlite", "postgres", "postgresql"] = Field(
default="sqlite",
description="Storage driver for both checkpointer and application data. 'sqlite' for single-node deployment (default),'postgres' for production multi-node deployment, 'mysql' for MySQL databases.",
)
sqlite_dir: str = Field(
default=".deer-flow/data",
description="Directory for SQLite .db files (sqlite driver only).",
)
username: str = Field(default="", description="db username ")
password: str = Field(default="", description="db password. Use $VAR syntax in config.yaml to read from .env.")
host: str = Field(default="localhost", description="db host.")
port: int = Field(default=5432, description="db port.")
db_name: str = Field(default="deerflow", description="db database name.")
database_url: str = Field(default="", description="Complete SQLAlchemy database URL. Takes precedence for non-SQLite drivers.")
sqlite_db_path: str = Field(default=".deer-flow/data", description="Directory for SQLite .db files (sqlite driver only).")
echo_sql: bool = Field(default=False, description="Log all SQL statements (debug only).")
pool_size: int = Field(default=5, description="Connection pool size per layer.")
# -- Derived helpers (not user-configured) --
@property
def _resolved_sqlite_dir(self) -> str:
"""Resolve sqlite_dir to an absolute path under DeerFlow's base dir."""
from pathlib import Path
path = Path(self.sqlite_dir)
if path.is_absolute():
return str(path.resolve())
try:
from deerflow.config.paths import resolve_path
return str(resolve_path(_strip_legacy_state_prefix(self.sqlite_dir)))
except ImportError:
return str(path.resolve())
@property
def sqlite_storage_path(self) -> str:
"""SQLite file path for storage-owned app data and checkpointer."""
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
@@ -0,0 +1,32 @@
from store.persistence.base_model import (
Base,
DataClassBase,
DateTimeMixin,
MappedBase,
TimeZone,
UniversalText,
id_key,
)
from .factory import (
create_persistence,
create_persistence_from_database_config,
create_persistence_from_storage_config,
storage_config_from_database_config,
)
from .types import AppPersistence
__all__ = [
"Base",
"DataClassBase",
"DateTimeMixin",
"MappedBase",
"TimeZone",
"UniversalText",
"id_key",
"create_persistence",
"create_persistence_from_database_config",
"create_persistence_from_storage_config",
"storage_config_from_database_config",
"AppPersistence",
]
@@ -0,0 +1,111 @@
from datetime import datetime
from typing import Annotated
from sqlalchemy import BigInteger, DateTime, Integer, Text, TypeDecorator
from sqlalchemy.dialects.mysql import LONGTEXT
from sqlalchemy.ext.asyncio import AsyncAttrs
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
from store.utils import get_timezone
def current_time() -> datetime:
return get_timezone().now()
id_key = Annotated[
int,
mapped_column(
BigInteger().with_variant(Integer, "sqlite"),
primary_key=True,
unique=True,
index=True,
autoincrement=True,
sort_order=-999,
comment="Primary key ID",
),
]
class UniversalText(TypeDecorator[str]):
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
impl = Text
cache_ok = True
def load_dialect_impl(self, dialect): # noqa: ANN001
if dialect.name == "mysql":
return dialect.type_descriptor(LONGTEXT())
return dialect.type_descriptor(Text())
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
return value
def process_result_value(self, value: str | None, dialect) -> str | None: # noqa: ANN001
return value
class TimeZone(TypeDecorator[datetime]):
"""Timezone-aware datetime type compatible with PostgreSQL and MySQL."""
impl = DateTime(timezone=True)
cache_ok = True
@property
def python_type(self) -> type[datetime]:
return datetime
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
timezone = get_timezone()
if value is not None and value.utcoffset() != timezone.now().utcoffset():
value = timezone.from_datetime(value)
return value
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
timezone = get_timezone()
if value is not None and value.tzinfo is None:
value = value.replace(tzinfo=timezone.tz_info)
return value
class DateTimeMixin(MappedAsDataclass):
"""Mixin that adds created_time / updated_time columns."""
created_time: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
TimeZone,
init=False,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
class MappedBase(AsyncAttrs, DeclarativeBase):
"""Async-capable declarative base for all ORM models."""
@declared_attr.directive
def __tablename__(self) -> str:
return self.__name__.lower()
@declared_attr.directive
def __table_args__(self) -> dict:
return {"comment": self.__doc__ or ""}
class DataClassBase(MappedAsDataclass, MappedBase):
"""Declarative base with native dataclass integration."""
__abstract__ = True
class Base(DataClassBase, DateTimeMixin):
"""Declarative dataclass base with created_time / updated_time columns."""
__abstract__ = True
@@ -0,0 +1,9 @@
from .mysql import build_mysql_persistence
from .postgres import build_postgres_persistence
from .sqlite import build_sqlite_persistence
__all__ = [
"build_postgres_persistence",
"build_mysql_persistence",
"build_sqlite_persistence",
]
@@ -0,0 +1,76 @@
from __future__ import annotations
import json
from sqlalchemy import URL
from sqlalchemy.engine import make_url
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence
def _validate_mysql_driver(db_url: URL) -> str:
url = make_url(db_url)
driver = url.get_driver_name()
if driver not in {"aiomysql", "asyncmy"}:
raise ValueError(f"MySQL persistence requires async SQLAlchemy driver (aiomysql/asyncmy), got: {driver!r}")
return driver
def _checkpoint_conn_string(db_url: URL) -> str:
return db_url.render_as_string(hide_password=False)
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
_validate_mysql_driver(db_url)
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver
import store.repositories.models # noqa: F401
engine = create_async_engine(
db_url,
echo=echo,
future=True,
pool_pre_ping=True,
pool_size=pool_size,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(db_url))
checkpointer = await saver_cm.__aenter__()
async def setup() -> None:
# 1. LangGraph checkpoint tables / migrations
await checkpointer.setup()
# 2. ORM business tables
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
async def _close_saver() -> None:
await saver_cm.__aexit__(None, None, None)
async def aclose() -> None:
await close_in_order(
engine.dispose,
_close_saver,
)
return AppPersistence(
checkpointer=checkpointer,
engine=engine,
session_factory=session_factory,
setup=setup,
aclose=aclose,
)
@@ -0,0 +1,64 @@
from __future__ import annotations
import json
from sqlalchemy import URL
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence
def _checkpoint_conn_string(db_url: URL) -> str:
return db_url.set(drivername="postgresql").render_as_string(hide_password=False)
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
import store.repositories.models # noqa: F401
engine = create_async_engine(
db_url,
echo=echo,
future=True,
pool_pre_ping=True,
pool_size=pool_size,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
saver_cm = AsyncPostgresSaver.from_conn_string(_checkpoint_conn_string(db_url))
checkpointer = await saver_cm.__aenter__()
async def setup() -> None:
# 1. LangGraph checkpoint tables / migrations
await checkpointer.setup()
# 2. ORM business tables
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
async def _close_saver() -> None:
await saver_cm.__aexit__(None, None, None)
async def aclose() -> None:
await close_in_order(
engine.dispose,
_close_saver,
)
return AppPersistence(
checkpointer=checkpointer,
engine=engine,
session_factory=session_factory,
setup=setup,
aclose=aclose,
)
@@ -0,0 +1,68 @@
from __future__ import annotations
import json
from sqlalchemy import URL, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.persistence import MappedBase
from store.persistence.shared import close_in_order
from store.persistence.types import AppPersistence
async def build_sqlite_persistence(db_url: URL, *, echo: bool = False) -> AppPersistence:
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
import store.repositories.models # noqa: F401
engine = create_async_engine(
db_url,
echo=echo,
future=True,
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
)
@event.listens_for(engine.sync_engine, "connect")
def _enable_sqlite_pragmas(dbapi_conn, _record): # noqa: ANN001
cursor = dbapi_conn.cursor()
try:
cursor.execute("PRAGMA journal_mode=WAL;")
cursor.execute("PRAGMA synchronous=NORMAL;")
cursor.execute("PRAGMA foreign_keys=ON;")
finally:
cursor.close()
session_factory = async_sessionmaker(
bind=engine,
class_=AsyncSession,
expire_on_commit=False,
autoflush=False,
)
saver_cm = AsyncSqliteSaver.from_conn_string(db_url.database)
checkpointer = await saver_cm.__aenter__()
async def setup() -> None:
# 1. LangGraph checkpoint tables
await checkpointer.setup()
# 2. ORM business tables
async with engine.begin() as conn:
await conn.run_sync(MappedBase.metadata.create_all)
async def _close_saver() -> None:
await saver_cm.__aexit__(None, None, None)
async def aclose() -> None:
await close_in_order(
engine.dispose,
_close_saver,
)
return AppPersistence(
checkpointer=checkpointer,
engine=engine,
session_factory=session_factory,
setup=setup,
aclose=aclose,
)
@@ -0,0 +1,123 @@
from typing import Any
from sqlalchemy import URL
from sqlalchemy.engine.url import make_url
from store.common import DataBaseType
from store.config.app_config import get_app_config
from store.config.storage_config import StorageConfig
from store.persistence.types import AppPersistence
def storage_config_from_database_config(database_config: Any) -> StorageConfig:
"""Convert the existing public DatabaseConfig shape to StorageConfig.
Storage only owns durable database-backed persistence. The app bridge
should handle memory mode before calling into this package.
"""
backend = getattr(database_config, "backend", None)
if backend == "sqlite":
return StorageConfig(
driver="sqlite",
sqlite_dir=getattr(database_config, "sqlite_dir", ".deer-flow/data"),
echo_sql=getattr(database_config, "echo_sql", False),
pool_size=getattr(database_config, "pool_size", 5),
)
if backend == "postgres":
postgres_url = getattr(database_config, "postgres_url", "")
if not postgres_url:
raise ValueError("database.postgres_url is required when database.backend is 'postgres'")
parsed = make_url(postgres_url)
return StorageConfig(
driver="postgres",
database_url=postgres_url,
username=parsed.username or "",
password=parsed.password or "",
host=parsed.host or "localhost",
port=parsed.port or 5432,
db_name=parsed.database or "deerflow",
echo_sql=getattr(database_config, "echo_sql", False),
pool_size=getattr(database_config, "pool_size", 5),
)
raise ValueError(f"Unsupported database backend for storage persistence: {backend!r}")
def _create_database_url(storage_config: StorageConfig) -> URL:
"""Build an async SQLAlchemy URL from StorageConfig (sqlite/mysql/postgres)."""
if storage_config.driver == DataBaseType.sqlite:
driver = "sqlite+aiosqlite"
elif storage_config.driver == DataBaseType.mysql:
driver = "mysql+aiomysql"
elif storage_config.driver in (DataBaseType.postgresql, "postgres"):
driver = "postgresql+asyncpg"
else:
raise ValueError(f"Unsupported database driver: {storage_config.driver}")
if storage_config.driver == DataBaseType.sqlite:
import os
db_path = storage_config.sqlite_storage_path
os.makedirs(os.path.dirname(db_path), exist_ok=True)
url = URL.create(
drivername=driver,
database=db_path,
)
elif storage_config.database_url:
url = make_url(storage_config.database_url)
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
url = url.set(drivername="postgresql+asyncpg")
elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
url = url.set(drivername="mysql+aiomysql")
else:
url = URL.create(
drivername=driver,
username=storage_config.username,
password=storage_config.password,
host=storage_config.host,
port=storage_config.port,
database=storage_config.db_name or "deerflow",
)
return url
async def create_persistence_from_storage_config(storage_config: StorageConfig) -> AppPersistence:
from .drivers.mysql import build_mysql_persistence
from .drivers.postgres import build_postgres_persistence
from .drivers.sqlite import build_sqlite_persistence
driver = storage_config.driver
db_url = _create_database_url(storage_config)
if driver in ("postgres", "postgresql"):
return await build_postgres_persistence(
db_url,
echo=storage_config.echo_sql,
pool_size=storage_config.pool_size,
)
if driver == "mysql":
return await build_mysql_persistence(
db_url,
echo=storage_config.echo_sql,
pool_size=storage_config.pool_size,
)
if driver == "sqlite":
return await build_sqlite_persistence(db_url, echo=storage_config.echo_sql)
raise ValueError(f"Unsupported database driver: {driver}")
async def create_persistence_from_database_config(database_config: Any) -> AppPersistence:
storage_config = storage_config_from_database_config(database_config)
return await create_persistence_from_storage_config(storage_config)
async def create_persistence() -> AppPersistence:
app_config = get_app_config()
return await create_persistence_from_storage_config(app_config.storage)
@@ -0,0 +1,189 @@
"""Dialect-aware JSON value matching for storage SQLAlchemy repositories."""
from __future__ import annotations
import re
from dataclasses import dataclass
from typing import Any
from sqlalchemy import BigInteger, Float, String, bindparam
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.compiler import SQLCompiler
from sqlalchemy.sql.expression import ColumnElement
from sqlalchemy.sql.visitors import InternalTraversal
from sqlalchemy.types import Boolean, TypeEngine
_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)
_INT64_MIN = -(2**63)
_INT64_MAX = 2**63 - 1
def validate_metadata_filter_key(key: object) -> bool:
"""Return True when *key* is safe for JSON metadata filter SQL paths."""
return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key))
def validate_metadata_filter_value(value: object) -> bool:
"""Return True when *value* can be compiled into a portable JSON predicate."""
if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
return False
if isinstance(value, int) and not isinstance(value, bool):
return _INT64_MIN <= value <= _INT64_MAX
return True
class JsonMatch(ColumnElement[bool]):
"""Dialect-portable ``column[key] == value`` for JSON columns."""
inherit_cache = True
type = Boolean()
_is_implicitly_boolean = True
_traverse_internals = [
("column", InternalTraversal.dp_clauseelement),
("key", InternalTraversal.dp_string),
("value", InternalTraversal.dp_plain_obj),
("value_type", InternalTraversal.dp_string),
]
def __init__(self, column: ColumnElement[Any], key: str, value: object) -> None:
if not validate_metadata_filter_key(key):
raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")
if not validate_metadata_filter_value(value):
if isinstance(value, int) and not isinstance(value, bool):
raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}")
raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}")
self.column = column
self.key = key
self.value = value
self.value_type = type(value).__qualname__
super().__init__()
@dataclass(frozen=True)
class _Dialect:
null_type: str
num_types: tuple[str, ...]
num_cast: str
int_types: tuple[str, ...]
int_cast: str
int_guard: str | None
string_type: str
bool_type: str | None
true_value: str
false_value: str
_SQLITE = _Dialect(
null_type="null",
num_types=("integer", "real"),
num_cast="REAL",
int_types=("integer",),
int_cast="INTEGER",
int_guard=None,
string_type="text",
bool_type=None,
true_value="true",
false_value="false",
)
_POSTGRES = _Dialect(
null_type="null",
num_types=("number",),
num_cast="DOUBLE PRECISION",
int_types=("number",),
int_cast="BIGINT",
int_guard="'^-?[0-9]+$'",
string_type="string",
bool_type="boolean",
true_value="true",
false_value="false",
)
_MYSQL = _Dialect(
null_type="NULL",
num_types=("INTEGER", "DOUBLE", "DECIMAL"),
num_cast="DOUBLE",
int_types=("INTEGER",),
int_cast="SIGNED",
int_guard=None,
string_type="STRING",
bool_type="BOOLEAN",
true_value="true",
false_value="false",
)
def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
param = bindparam(None, value, type_=sa_type)
return compiler.process(param, **kw)
def _type_check(typeof: str, types: tuple[str, ...]) -> str:
if len(types) == 1:
return f"{typeof} = '{types[0]}'"
quoted = ", ".join(f"'{type_name}'" for type_name in types)
return f"{typeof} IN ({quoted})"
def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
if value is None:
return f"{typeof} = '{dialect.null_type}'"
if isinstance(value, bool):
bool_str = dialect.true_value if value else dialect.false_value
if dialect.bool_type is None:
return f"{typeof} = '{bool_str}'"
return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')"
if isinstance(value, int):
bp = _bind(compiler, value, BigInteger(), **kw)
if dialect.int_guard:
return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})"
if isinstance(value, float):
bp = _bind(compiler, value, Float(), **kw)
return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})"
bp = _bind(compiler, str(value), String(), **kw)
return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})"
@compiles(JsonMatch, "sqlite")
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
path = f'$."{element.key}"'
typeof = f"json_type({col}, '{path}')"
extract = f"json_extract({col}, '{path}')"
return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw)
@compiles(JsonMatch, "postgresql")
def _compile_postgres(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
typeof = f"json_typeof({col} -> '{element.key}')"
extract = f"({col} ->> '{element.key}')"
return _build_clause(compiler, typeof, extract, element.value, _POSTGRES, **kw)
@compiles(JsonMatch, "mysql")
def _compile_mysql(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
if not validate_metadata_filter_key(element.key):
raise ValueError(f"Key escaped validation: {element.key!r}")
col = compiler.process(element.column, **kw)
path = f'$."{element.key}"'
typeof = f"JSON_TYPE(JSON_EXTRACT({col}, '{path}'))"
extract = f"JSON_UNQUOTE(JSON_EXTRACT({col}, '{path}'))"
return _build_clause(compiler, typeof, extract, element.value, _MYSQL, **kw)
@compiles(JsonMatch)
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
raise NotImplementedError(f"JsonMatch supports sqlite, postgresql, and mysql; got dialect: {compiler.dialect.name}")
def json_match(column: ColumnElement[Any], key: str, value: object) -> JsonMatch:
return JsonMatch(column, key, value)
@@ -0,0 +1,3 @@
from .close import close_in_order
__all__ = ["close_in_order"]
@@ -0,0 +1,28 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
AsyncCloser = Callable[[], Awaitable[None]]
async def close_in_order(*closers: AsyncCloser) -> None:
"""
Run async closers in order and raise the first error, if any.
Notes
-----
- Used to keep driver-specific close logic readable.
- We intentionally do not stop at first failure, so later resources
still get a chance to close.
"""
first_error: Exception | None = None
for closer in closers:
try:
await closer()
except Exception as exc:
if first_error is None:
first_error = exc
if first_error is not None:
raise first_error
@@ -0,0 +1,23 @@
from __future__ import annotations
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from langgraph.types import Checkpointer
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
AsyncSetup = Callable[[], Awaitable[None]]
AsyncClose = Callable[[], Awaitable[None]]
@dataclass(slots=True)
class AppPersistence:
"""
Unified runtime persistence bundle.
"""
checkpointer: Checkpointer
engine: AsyncEngine
session_factory: async_sessionmaker[AsyncSession]
setup: AsyncSetup
aclose: AsyncClose
@@ -0,0 +1,53 @@
from store.repositories.contracts import (
Feedback,
FeedbackAggregate,
FeedbackCreate,
FeedbackRepositoryProtocol,
InvalidMetadataFilterError,
Run,
RunCreate,
RunEvent,
RunEventCreate,
RunEventRepositoryProtocol,
RunRepositoryProtocol,
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
User,
UserCreate,
UserNotFoundError,
UserRepositoryProtocol,
)
from store.repositories.factory import (
build_feedback_repository,
build_run_event_repository,
build_run_repository,
build_thread_meta_repository,
build_user_repository,
)
__all__ = [
"Feedback",
"FeedbackAggregate",
"FeedbackCreate",
"FeedbackRepositoryProtocol",
"InvalidMetadataFilterError",
"Run",
"RunCreate",
"RunEvent",
"RunEventCreate",
"RunEventRepositoryProtocol",
"RunRepositoryProtocol",
"ThreadMeta",
"ThreadMetaCreate",
"ThreadMetaRepositoryProtocol",
"User",
"UserCreate",
"UserNotFoundError",
"UserRepositoryProtocol",
"build_run_repository",
"build_run_event_repository",
"build_thread_meta_repository",
"build_feedback_repository",
"build_user_repository",
]
@@ -0,0 +1,49 @@
from store.repositories.contracts.feedback import (
Feedback,
FeedbackAggregate,
FeedbackCreate,
FeedbackRepositoryProtocol,
)
from store.repositories.contracts.run import (
Run,
RunCreate,
RunRepositoryProtocol,
)
from store.repositories.contracts.run_event import (
RunEvent,
RunEventCreate,
RunEventRepositoryProtocol,
)
from store.repositories.contracts.thread_meta import (
InvalidMetadataFilterError,
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
)
from store.repositories.contracts.user import (
User,
UserCreate,
UserNotFoundError,
UserRepositoryProtocol,
)
__all__ = [
"Feedback",
"FeedbackAggregate",
"FeedbackCreate",
"FeedbackRepositoryProtocol",
"Run",
"RunCreate",
"RunEvent",
"RunEventCreate",
"RunEventRepositoryProtocol",
"RunRepositoryProtocol",
"InvalidMetadataFilterError",
"ThreadMeta",
"ThreadMetaCreate",
"ThreadMetaRepositoryProtocol",
"User",
"UserCreate",
"UserNotFoundError",
"UserRepositoryProtocol",
]
@@ -0,0 +1,77 @@
from __future__ import annotations
from datetime import datetime
from typing import Protocol, TypedDict
from pydantic import BaseModel, ConfigDict
class FeedbackCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
feedback_id: str
run_id: str
thread_id: str
rating: int
user_id: str | None = None
message_id: str | None = None
comment: str | None = None
class Feedback(BaseModel):
model_config = ConfigDict(frozen=True)
feedback_id: str
run_id: str
thread_id: str
rating: int
user_id: str | None
message_id: str | None
comment: str | None
created_time: datetime
class FeedbackAggregate(TypedDict):
run_id: str
total: int
positive: int
negative: int
class FeedbackRepositoryProtocol(Protocol):
async def create_feedback(self, data: FeedbackCreate) -> Feedback:
pass
async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
pass
async def get_feedback(self, feedback_id: str) -> Feedback | None:
pass
async def list_feedback_by_run(
self,
run_id: str,
*,
thread_id: str | None = None,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
pass
async def list_feedback_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
pass
async def delete_feedback(self, feedback_id: str) -> bool:
pass
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
pass
async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
pass
@@ -0,0 +1,100 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Protocol
from pydantic import BaseModel, ConfigDict, Field
class RunCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
run_id: str
thread_id: str
assistant_id: str | None = None
user_id: str | None = None
status: str = "pending"
model_name: str | None = None
multitask_strategy: str = "reject"
error: str | None = None
follow_up_to_run_id: str | None = None
metadata: dict[str, Any] = Field(default_factory=dict)
kwargs: dict[str, Any] = Field(default_factory=dict)
created_time: datetime | None = None
class Run(BaseModel):
model_config = ConfigDict(frozen=True)
run_id: str
thread_id: str
assistant_id: str | None
user_id: str | None
status: str
model_name: str | None
multitask_strategy: str
error: str | None
follow_up_to_run_id: str | None
metadata: dict[str, Any]
kwargs: dict[str, Any]
total_input_tokens: int
total_output_tokens: int
total_tokens: int
llm_call_count: int
lead_agent_tokens: int
subagent_tokens: int
middleware_tokens: int
message_count: int
first_human_message: str | None
last_ai_message: str | None
created_time: datetime
updated_time: datetime | None
class RunRepositoryProtocol(Protocol):
async def create_run(self, data: RunCreate) -> Run:
pass
async def get_run(self, run_id: str) -> Run | None:
pass
async def list_runs_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[Run]:
pass
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
pass
async def delete_run(self, run_id: str) -> None:
pass
async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
pass
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
first_human_message: str | None = None,
last_ai_message: str | None = None,
error: str | None = None,
) -> None:
pass
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
pass
@@ -0,0 +1,83 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Protocol
from pydantic import BaseModel, ConfigDict, Field
class RunEventCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
thread_id: str
run_id: str
user_id: str | None = None
event_type: str
category: str
content: Any = ""
metadata: dict[str, Any] = Field(default_factory=dict)
created_at: datetime | None = None
class RunEvent(BaseModel):
model_config = ConfigDict(frozen=True)
thread_id: str
run_id: str
user_id: str | None
event_type: str
category: str
content: Any
metadata: dict[str, Any]
seq: int
created_at: datetime
class RunEventRepositoryProtocol(Protocol):
# Sequence values are time-ordered integer cursors. The application layer
# owns the single-writer invariant for a thread while a run is active.
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
pass
async def list_messages(
self,
thread_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
pass
async def list_events(
self,
thread_id: str,
run_id: str,
*,
event_types: list[str] | None = None,
limit: int = 500,
user_id: str | None = None,
) -> list[RunEvent]:
pass
async def list_messages_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
pass
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
pass
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
pass
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
pass
@@ -0,0 +1,67 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Protocol
from pydantic import BaseModel, ConfigDict, Field
class InvalidMetadataFilterError(ValueError):
"""Raised when all client-supplied metadata filters are rejected."""
class ThreadMetaCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
thread_id: str
assistant_id: str | None = None
user_id: str | None = None
display_name: str | None = None
status: str = "idle"
metadata: dict[str, Any] = Field(default_factory=dict)
class ThreadMeta(BaseModel):
model_config = ConfigDict(frozen=True)
thread_id: str
assistant_id: str | None
user_id: str | None
display_name: str | None
status: str
metadata: dict[str, Any]
created_time: datetime
updated_time: datetime | None
class ThreadMetaRepositoryProtocol(Protocol):
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
pass
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
pass
async def update_thread_meta(
self,
thread_id: str,
*,
display_name: str | None = None,
status: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
pass
async def delete_thread(self, thread_id: str) -> None:
pass
async def search_threads(
self,
*,
metadata: dict[str, Any] | None = None,
status: str | None = None,
user_id: str | None = None,
assistant_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ThreadMeta]:
pass
@@ -0,0 +1,64 @@
from __future__ import annotations
from datetime import datetime
from typing import Literal, Protocol
from pydantic import BaseModel, ConfigDict
class UserNotFoundError(LookupError):
"""Raised when an update targets a user row that no longer exists."""
class UserCreate(BaseModel):
model_config = ConfigDict(extra="forbid")
id: str
email: str
password_hash: str | None = None
system_role: Literal["admin", "user"] = "user"
created_at: datetime | None = None
oauth_provider: str | None = None
oauth_id: str | None = None
needs_setup: bool = False
token_version: int = 0
class User(BaseModel):
model_config = ConfigDict(frozen=True)
id: str
email: str
password_hash: str | None
system_role: Literal["admin", "user"]
created_at: datetime
oauth_provider: str | None
oauth_id: str | None
needs_setup: bool
token_version: int
class UserRepositoryProtocol(Protocol):
async def create_user(self, data: UserCreate) -> User:
pass
async def get_user_by_id(self, user_id: str) -> User | None:
pass
async def get_user_by_email(self, email: str) -> User | None:
pass
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
pass
async def get_first_admin(self) -> User | None:
pass
async def update_user(self, data: User) -> User:
pass
async def count_users(self) -> int:
pass
async def count_admin_users(self) -> int:
pass
@@ -0,0 +1,13 @@
from store.repositories.db.feedback import DbFeedbackRepository
from store.repositories.db.run import DbRunRepository
from store.repositories.db.run_event import DbRunEventRepository
from store.repositories.db.thread_meta import DbThreadMetaRepository
from store.repositories.db.user import DbUserRepository
__all__ = [
"DbFeedbackRepository",
"DbRunRepository",
"DbRunEventRepository",
"DbThreadMetaRepository",
"DbUserRepository",
]
@@ -0,0 +1,142 @@
from __future__ import annotations
from datetime import UTC, datetime
from sqlalchemy import case, delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.feedback import Feedback, FeedbackAggregate, FeedbackCreate, FeedbackRepositoryProtocol
from store.repositories.models.feedback import Feedback as FeedbackModel
def _to_feedback(m: FeedbackModel) -> Feedback:
return Feedback(
feedback_id=m.feedback_id,
run_id=m.run_id,
thread_id=m.thread_id,
rating=m.rating,
user_id=m.user_id,
message_id=m.message_id,
comment=m.comment,
created_time=m.created_time,
)
class DbFeedbackRepository(FeedbackRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_feedback(self, data: FeedbackCreate) -> Feedback:
if data.rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {data.rating}")
model = FeedbackModel(
feedback_id=data.feedback_id,
run_id=data.run_id,
thread_id=data.thread_id,
rating=data.rating,
user_id=data.user_id,
message_id=data.message_id,
comment=data.comment,
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return _to_feedback(model)
async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
if data.rating not in (1, -1):
raise ValueError(f"rating must be +1 or -1, got {data.rating}")
result = await self._session.execute(
select(FeedbackModel).where(
FeedbackModel.thread_id == data.thread_id,
FeedbackModel.run_id == data.run_id,
FeedbackModel.user_id == data.user_id,
)
)
model = result.scalar_one_or_none()
if model is None:
return await self.create_feedback(data)
model.rating = data.rating
model.message_id = data.message_id
model.comment = data.comment
model.created_time = datetime.now(UTC)
await self._session.flush()
await self._session.refresh(model)
return _to_feedback(model)
async def get_feedback(self, feedback_id: str) -> Feedback | None:
result = await self._session.execute(select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
model = result.scalar_one_or_none()
return _to_feedback(model) if model else None
async def list_feedback_by_run(
self,
run_id: str,
*,
thread_id: str | None = None,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
stmt = select(FeedbackModel).where(FeedbackModel.run_id == run_id)
if thread_id is not None:
stmt = stmt.where(FeedbackModel.thread_id == thread_id)
if user_id is not None:
stmt = stmt.where(FeedbackModel.user_id == user_id)
stmt = stmt.order_by(FeedbackModel.created_time.desc())
if limit is not None:
stmt = stmt.limit(limit)
result = await self._session.execute(stmt)
return [_to_feedback(m) for m in result.scalars().all()]
async def list_feedback_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int | None = None,
) -> list[Feedback]:
stmt = select(FeedbackModel).where(FeedbackModel.thread_id == thread_id)
if user_id is not None:
stmt = stmt.where(FeedbackModel.user_id == user_id)
stmt = stmt.order_by(FeedbackModel.created_time.desc())
if limit is not None:
stmt = stmt.limit(limit)
result = await self._session.execute(stmt)
return [_to_feedback(m) for m in result.scalars().all()]
async def delete_feedback(self, feedback_id: str) -> bool:
existing = await self.get_feedback(feedback_id)
if existing is None:
return False
await self._session.execute(delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
return True
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
stmt = select(FeedbackModel).where(
FeedbackModel.thread_id == thread_id,
FeedbackModel.run_id == run_id,
)
if user_id is not None:
stmt = stmt.where(FeedbackModel.user_id == user_id)
result = await self._session.execute(stmt)
model = result.scalar_one_or_none()
if model is None:
return False
await self._session.delete(model)
return True
async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
stmt = select(
func.count().label("total"),
func.coalesce(func.sum(case((FeedbackModel.rating == 1, 1), else_=0)), 0).label("positive"),
func.coalesce(func.sum(case((FeedbackModel.rating == -1, 1), else_=0)), 0).label("negative"),
).where(FeedbackModel.thread_id == thread_id, FeedbackModel.run_id == run_id)
row = (await self._session.execute(stmt)).one()
return {
"run_id": run_id,
"total": int(row.total),
"positive": int(row.positive),
"negative": int(row.negative),
}
@@ -0,0 +1,185 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import delete, func, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.run import Run, RunCreate, RunRepositoryProtocol
from store.repositories.models.run import Run as RunModel
def _to_run(m: RunModel) -> Run:
return Run(
run_id=m.run_id,
thread_id=m.thread_id,
assistant_id=m.assistant_id,
user_id=m.user_id,
status=m.status,
model_name=m.model_name,
multitask_strategy=m.multitask_strategy,
error=m.error,
follow_up_to_run_id=m.follow_up_to_run_id,
metadata=dict(m.meta or {}),
kwargs=dict(m.kwargs or {}),
total_input_tokens=m.total_input_tokens,
total_output_tokens=m.total_output_tokens,
total_tokens=m.total_tokens,
llm_call_count=m.llm_call_count,
lead_agent_tokens=m.lead_agent_tokens,
subagent_tokens=m.subagent_tokens,
middleware_tokens=m.middleware_tokens,
message_count=m.message_count,
first_human_message=m.first_human_message,
last_ai_message=m.last_ai_message,
created_time=m.created_time,
updated_time=m.updated_time,
)
class DbRunRepository(RunRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_run(self, data: RunCreate) -> Run:
model = RunModel(
run_id=data.run_id,
thread_id=data.thread_id,
assistant_id=data.assistant_id,
user_id=data.user_id,
status=data.status,
model_name=data.model_name,
multitask_strategy=data.multitask_strategy,
error=data.error,
follow_up_to_run_id=data.follow_up_to_run_id,
meta=dict(data.metadata),
kwargs=dict(data.kwargs),
)
if data.created_time is not None:
model.created_time = data.created_time
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return _to_run(model)
async def get_run(self, run_id: str) -> Run | None:
result = await self._session.execute(select(RunModel).where(RunModel.run_id == run_id))
model = result.scalar_one_or_none()
return _to_run(model) if model else None
async def list_runs_by_thread(
self,
thread_id: str,
*,
user_id: str | None = None,
limit: int = 50,
offset: int = 0,
) -> list[Run]:
stmt = select(RunModel).where(RunModel.thread_id == thread_id)
if user_id is not None:
stmt = stmt.where(RunModel.user_id == user_id)
stmt = stmt.order_by(RunModel.created_time.desc()).limit(limit).offset(offset)
result = await self._session.execute(stmt)
return [_to_run(m) for m in result.scalars().all()]
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
values: dict = {"status": status}
if error is not None:
values["error"] = error
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
async def delete_run(self, run_id: str) -> None:
await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))
async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
if before is None:
before_dt = datetime.now().astimezone()
elif isinstance(before, datetime):
before_dt = before
else:
before_dt = datetime.fromisoformat(before)
result = await self._session.execute(select(RunModel).where(RunModel.status == "pending", RunModel.created_time <= before_dt).order_by(RunModel.created_time.asc()))
return [_to_run(m) for m in result.scalars().all()]
async def update_run_completion(
self,
run_id: str,
*,
status: str,
total_input_tokens: int = 0,
total_output_tokens: int = 0,
total_tokens: int = 0,
llm_call_count: int = 0,
lead_agent_tokens: int = 0,
subagent_tokens: int = 0,
middleware_tokens: int = 0,
message_count: int = 0,
first_human_message: str | None = None,
last_ai_message: str | None = None,
error: str | None = None,
) -> None:
values = {
"status": status,
"total_input_tokens": total_input_tokens,
"total_output_tokens": total_output_tokens,
"total_tokens": total_tokens,
"llm_call_count": llm_call_count,
"lead_agent_tokens": lead_agent_tokens,
"subagent_tokens": subagent_tokens,
"middleware_tokens": middleware_tokens,
"message_count": message_count,
}
if first_human_message is not None:
values["first_human_message"] = first_human_message[:2000]
if last_ai_message is not None:
values["last_ai_message"] = last_ai_message[:2000]
if error is not None:
values["error"] = error
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
completed = RunModel.status.in_(("success", "error"))
model_expr = func.coalesce(RunModel.model_name, "unknown")
stmt = (
select(
model_expr.label("model"),
func.count().label("runs"),
func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
func.coalesce(func.sum(RunModel.total_output_tokens), 0).label("total_output_tokens"),
func.coalesce(func.sum(RunModel.lead_agent_tokens), 0).label("lead_agent"),
func.coalesce(func.sum(RunModel.subagent_tokens), 0).label("subagent"),
func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
)
.where(RunModel.thread_id == thread_id, completed)
.group_by(model_expr)
)
rows = (await self._session.execute(stmt)).all()
total_tokens = total_input = total_output = total_runs = 0
lead_agent = subagent = middleware = 0
by_model: dict[str, dict] = {}
for row in rows:
by_model[row.model] = {"tokens": row.total_tokens, "runs": row.runs}
total_tokens += row.total_tokens
total_input += row.total_input_tokens
total_output += row.total_output_tokens
total_runs += row.runs
lead_agent += row.lead_agent
subagent += row.subagent
middleware += row.middleware
return {
"total_tokens": total_tokens,
"total_input_tokens": total_input,
"total_output_tokens": total_output,
"total_runs": total_runs,
"by_model": by_model,
"by_caller": {
"lead_agent": lead_agent,
"subagent": subagent,
"middleware": middleware,
},
}
@@ -0,0 +1,207 @@
from __future__ import annotations
import json
import secrets
import threading
import time
from typing import Any
from sqlalchemy import delete, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
from store.repositories.models.run_event import RunEvent as RunEventModel
_SEQ_COUNTER_BITS = 12
_SEQ_PROCESS_BITS = 9
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
class _SequenceAllocator:
def __init__(self) -> None:
self._last_millis = 0
self._lock = threading.Lock()
def allocate_base(self, batch_size: int) -> int:
if batch_size >= _SEQ_COUNTER_LIMIT:
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
now_ms = time.time_ns() // 1_000_000
with self._lock:
seq_ms = max(now_ms, self._last_millis + 1)
self._last_millis = seq_ms
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
_sequence_allocator = _SequenceAllocator()
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
if not isinstance(content, str):
next_metadata = {**metadata, "content_is_json": True}
if isinstance(content, dict):
next_metadata["content_is_dict"] = True
return json.dumps(content, default=str, ensure_ascii=False), next_metadata
return content, metadata
def _deserialize_content(content: str, metadata: dict[str, Any]) -> Any:
if not (metadata.get("content_is_json") or metadata.get("content_is_dict")):
return content
try:
return json.loads(content)
except json.JSONDecodeError:
return content
def _to_run_event(model: RunEventModel) -> RunEvent:
raw_metadata = dict(model.meta or {})
metadata = {key: value for key, value in raw_metadata.items() if key != "content_is_dict"}
return RunEvent(
thread_id=model.thread_id,
run_id=model.run_id,
user_id=model.user_id,
event_type=model.event_type,
category=model.category,
content=_deserialize_content(model.content, raw_metadata),
metadata=metadata,
seq=model.seq,
created_at=model.created_at,
)
class DbRunEventRepository(RunEventRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
if not events:
return []
seq_base = _sequence_allocator.allocate_base(len(events))
rows: list[RunEventModel] = []
for index, event in enumerate(events, start=1):
content, metadata = _serialize_content(event.content, dict(event.metadata))
row = RunEventModel(
thread_id=event.thread_id,
run_id=event.run_id,
user_id=event.user_id,
seq=seq_base + index,
event_type=event.event_type,
category=event.category,
content=content,
meta=metadata,
)
if event.created_at is not None:
row.created_at = event.created_at
self._session.add(row)
rows.append(row)
await self._session.flush()
return [_to_run_event(row) for row in rows]
async def list_messages(
self,
thread_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
stmt = select(RunEventModel).where(
RunEventModel.thread_id == thread_id,
RunEventModel.category == "message",
)
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
if before_seq is not None:
stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
if after_seq is not None:
stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
result = await self._session.execute(stmt)
return [_to_run_event(row) for row in result.scalars().all()]
stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
async def list_events(
self,
thread_id: str,
run_id: str,
*,
event_types: list[str] | None = None,
limit: int = 500,
user_id: str | None = None,
) -> list[RunEvent]:
stmt = select(RunEventModel).where(
RunEventModel.thread_id == thread_id,
RunEventModel.run_id == run_id,
)
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
if event_types is not None:
stmt = stmt.where(RunEventModel.event_type.in_(event_types))
stmt = stmt.order_by(RunEventModel.seq.asc()).limit(limit)
result = await self._session.execute(stmt)
return [_to_run_event(row) for row in result.scalars().all()]
async def list_messages_by_run(
self,
thread_id: str,
run_id: str,
*,
limit: int = 50,
before_seq: int | None = None,
after_seq: int | None = None,
user_id: str | None = None,
) -> list[RunEvent]:
stmt = select(RunEventModel).where(
RunEventModel.thread_id == thread_id,
RunEventModel.run_id == run_id,
RunEventModel.category == "message",
)
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
if before_seq is not None:
stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
if after_seq is not None:
stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
result = await self._session.execute(stmt)
return [_to_run_event(row) for row in result.scalars().all()]
stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
result = await self._session.execute(stmt)
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
stmt = select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
if user_id is not None:
stmt = stmt.where(RunEventModel.user_id == user_id)
count = await self._session.scalar(stmt)
return int(count or 0)
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
conditions = [RunEventModel.thread_id == thread_id]
if user_id is not None:
conditions.append(RunEventModel.user_id == user_id)
count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
await self._session.execute(delete(RunEventModel).where(*conditions))
return int(count or 0)
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
conditions = [RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id]
if user_id is not None:
conditions.append(RunEventModel.user_id == user_id)
count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
await self._session.execute(delete(RunEventModel).where(*conditions))
return int(count or 0)
@@ -0,0 +1,113 @@
from __future__ import annotations
import logging
from typing import Any
from sqlalchemy import delete, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from store.persistence.json_compat import json_match
from store.repositories.contracts.thread_meta import (
InvalidMetadataFilterError,
ThreadMeta,
ThreadMetaCreate,
ThreadMetaRepositoryProtocol,
)
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
logger = logging.getLogger(__name__)
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
return ThreadMeta(
thread_id=m.thread_id,
assistant_id=m.assistant_id,
user_id=m.user_id,
display_name=m.display_name,
status=m.status,
metadata=dict(m.meta or {}),
created_time=m.created_time,
updated_time=m.updated_time,
)
class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
model = ThreadMetaModel(
thread_id=data.thread_id,
assistant_id=data.assistant_id,
user_id=data.user_id,
display_name=data.display_name,
status=data.status,
meta=dict(data.metadata),
)
self._session.add(model)
await self._session.flush()
await self._session.refresh(model)
return _to_thread_meta(model)
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
result = await self._session.execute(select(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
model = result.scalar_one_or_none()
return _to_thread_meta(model) if model else None
async def update_thread_meta(
self,
thread_id: str,
*,
display_name: str | None = None,
status: str | None = None,
metadata: dict[str, Any] | None = None,
) -> None:
values: dict = {}
if display_name is not None:
values["display_name"] = display_name
if status is not None:
values["status"] = status
if metadata is not None:
values["meta"] = dict(metadata)
if not values:
return
await self._session.execute(update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))
async def delete_thread(self, thread_id: str) -> None:
await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
async def search_threads(
self,
*,
metadata: dict[str, Any] | None = None,
status: str | None = None,
user_id: str | None = None,
assistant_id: str | None = None,
limit: int = 100,
offset: int = 0,
) -> list[ThreadMeta]:
stmt = select(ThreadMetaModel)
if status is not None:
stmt = stmt.where(ThreadMetaModel.status == status)
if user_id is not None:
stmt = stmt.where(ThreadMetaModel.user_id == user_id)
if assistant_id is not None:
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
if metadata:
applied = 0
for key, value in metadata.items():
try:
stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
applied += 1
except (ValueError, TypeError) as exc:
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
if applied == 0:
rejected_keys = ", ".join(sorted(str(key) for key in metadata))
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
stmt = stmt.limit(limit).offset(offset)
result = await self._session.execute(stmt)
return [_to_thread_meta(m) for m in result.scalars().all()]
@@ -0,0 +1,98 @@
from __future__ import annotations
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories.contracts.user import User, UserCreate, UserNotFoundError, UserRepositoryProtocol
from store.repositories.models.user import User as UserModel
def _to_user(model: UserModel) -> User:
return User(
id=model.id,
email=model.email,
password_hash=model.password_hash,
system_role=model.system_role, # type: ignore[arg-type]
created_at=model.created_at,
oauth_provider=model.oauth_provider,
oauth_id=model.oauth_id,
needs_setup=model.needs_setup,
token_version=model.token_version,
)
class DbUserRepository(UserRepositoryProtocol):
def __init__(self, session: AsyncSession) -> None:
self._session = session
async def create_user(self, data: UserCreate) -> User:
model = UserModel(
id=data.id,
email=data.email,
system_role=data.system_role,
password_hash=data.password_hash,
oauth_provider=data.oauth_provider,
oauth_id=data.oauth_id,
needs_setup=data.needs_setup,
token_version=data.token_version,
)
if data.created_at is not None:
model.created_at = data.created_at
self._session.add(model)
try:
await self._session.flush()
except IntegrityError as exc:
await self._session.rollback()
raise ValueError(f"Email already registered: {data.email}") from exc
await self._session.refresh(model)
return _to_user(model)
async def get_user_by_id(self, user_id: str) -> User | None:
model = await self._session.get(UserModel, user_id)
return _to_user(model) if model is not None else None
async def get_user_by_email(self, email: str) -> User | None:
result = await self._session.execute(select(UserModel).where(UserModel.email == email))
model = result.scalar_one_or_none()
return _to_user(model) if model is not None else None
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
result = await self._session.execute(
select(UserModel).where(
UserModel.oauth_provider == provider,
UserModel.oauth_id == oauth_id,
)
)
model = result.scalar_one_or_none()
return _to_user(model) if model is not None else None
async def get_first_admin(self) -> User | None:
result = await self._session.execute(select(UserModel).where(UserModel.system_role == "admin").limit(1))
model = result.scalar_one_or_none()
return _to_user(model) if model is not None else None
async def update_user(self, data: User) -> User:
model = await self._session.get(UserModel, data.id)
if model is None:
raise UserNotFoundError(f"User {data.id} no longer exists")
model.email = data.email
model.password_hash = data.password_hash
model.system_role = data.system_role
model.oauth_provider = data.oauth_provider
model.oauth_id = data.oauth_id
model.needs_setup = data.needs_setup
model.token_version = data.token_version
await self._session.flush()
await self._session.refresh(model)
return _to_user(model)
async def count_users(self) -> int:
count = await self._session.scalar(select(func.count()).select_from(UserModel))
return int(count or 0)
async def count_admin_users(self) -> int:
count = await self._session.scalar(select(func.count()).select_from(UserModel).where(UserModel.system_role == "admin"))
return int(count or 0)
@@ -0,0 +1,36 @@
from sqlalchemy.ext.asyncio import AsyncSession
from store.repositories import (
FeedbackRepositoryProtocol,
RunEventRepositoryProtocol,
RunRepositoryProtocol,
ThreadMetaRepositoryProtocol,
UserRepositoryProtocol,
)
from store.repositories.db import (
DbFeedbackRepository,
DbRunEventRepository,
DbRunRepository,
DbThreadMetaRepository,
DbUserRepository,
)
def build_thread_meta_repository(session: AsyncSession) -> ThreadMetaRepositoryProtocol:
return DbThreadMetaRepository(session)
def build_run_repository(session: AsyncSession) -> RunRepositoryProtocol:
return DbRunRepository(session)
def build_feedback_repository(session: AsyncSession) -> FeedbackRepositoryProtocol:
return DbFeedbackRepository(session)
def build_run_event_repository(session: AsyncSession) -> RunEventRepositoryProtocol:
return DbRunEventRepository(session)
def build_user_repository(session: AsyncSession) -> UserRepositoryProtocol:
return DbUserRepository(session)
@@ -0,0 +1,7 @@
from store.repositories.models.feedback import Feedback
from store.repositories.models.run import Run
from store.repositories.models.run_event import RunEvent
from store.repositories.models.thread_meta import ThreadMeta
from store.repositories.models.user import User
__all__ = ["Feedback", "Run", "RunEvent", "ThreadMeta", "User"]
@@ -0,0 +1,36 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import Integer, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
class Feedback(DataClassBase):
"""Feedback table (create-only, no updated_time)."""
__tablename__ = "feedback"
__table_args__ = (
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
{"comment": "Feedback table."},
)
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
run_id: Mapped[str] = mapped_column(String(64), index=True)
thread_id: Mapped[str] = mapped_column(String(64), index=True)
rating: Mapped[int] = mapped_column(Integer)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
message_id: Mapped[str | None] = mapped_column(String(64), default=None)
comment: Mapped[str | None] = mapped_column(UniversalText, default=None)
created_time: Mapped[datetime] = mapped_column(
"created_at",
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -0,0 +1,63 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, Index, Integer, String
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
class Run(DataClassBase):
"""Run metadata table."""
__tablename__ = "runs"
__table_args__ = (
Index("ix_runs_thread_status", "thread_id", "status"),
{"comment": "Run metadata table."},
)
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
thread_id: Mapped[str] = mapped_column(String(64), index=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), default=None)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
status: Mapped[str] = mapped_column(String(20), default="pending", index=True)
model_name: Mapped[str | None] = mapped_column(String(128), default=None)
multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
error: Mapped[str | None] = mapped_column(UniversalText, default=None)
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64), default=None)
meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
kwargs: Mapped[dict[str, Any]] = mapped_column("kwargs_json", JSON, default_factory=dict)
total_input_tokens: Mapped[int] = mapped_column(Integer, default=0)
total_output_tokens: Mapped[int] = mapped_column(Integer, default=0)
total_tokens: Mapped[int] = mapped_column(Integer, default=0)
llm_call_count: Mapped[int] = mapped_column(Integer, default=0)
lead_agent_tokens: Mapped[int] = mapped_column(Integer, default=0)
subagent_tokens: Mapped[int] = mapped_column(Integer, default=0)
middleware_tokens: Mapped[int] = mapped_column(Integer, default=0)
message_count: Mapped[int] = mapped_column(Integer, default=0)
first_human_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
last_ai_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
created_time: Mapped[datetime] = mapped_column(
"created_at",
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
"updated_at",
TimeZone,
init=False,
default=None,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -0,0 +1,46 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import (
DataClassBase,
TimeZone,
UniversalText,
current_time,
id_key,
)
class RunEvent(DataClassBase):
"""Run event table."""
__tablename__ = "run_events"
__table_args__ = (
UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
Index("ix_events_run", "thread_id", "run_id", "seq"),
{"comment": "Run event table."},
)
id: Mapped[id_key] = mapped_column(init=False)
thread_id: Mapped[str] = mapped_column(String(64), index=True)
run_id: Mapped[str] = mapped_column(String(64), index=True)
event_type: Mapped[str] = mapped_column(String(32), index=True)
category: Mapped[str] = mapped_column(String(16), index=True)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
content: Mapped[str] = mapped_column(UniversalText, default="")
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
created_at: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Event timestamp",
)
@@ -0,0 +1,43 @@
from __future__ import annotations
from datetime import datetime
from typing import Any
from sqlalchemy import JSON, String
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, current_time
class ThreadMeta(DataClassBase):
"""Thread metadata table."""
__tablename__ = "threads_meta"
__table_args__ = {"comment": "Thread metadata table."}
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
assistant_id: Mapped[str | None] = mapped_column(String(128), default=None, index=True)
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
display_name: Mapped[str | None] = mapped_column(String(256), default=None)
status: Mapped[str] = mapped_column(String(20), default="idle", index=True)
meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
created_time: Mapped[datetime] = mapped_column(
"created_at",
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
updated_time: Mapped[datetime | None] = mapped_column(
"updated_at",
TimeZone,
init=False,
default=None,
onupdate=current_time,
sort_order=999,
comment="Updated at",
)
@@ -0,0 +1,42 @@
from __future__ import annotations
from datetime import datetime
from sqlalchemy import Boolean, Index, String, text
from sqlalchemy.orm import Mapped, mapped_column
from store.persistence.base_model import DataClassBase, TimeZone, current_time
class User(DataClassBase):
"""User account table."""
__tablename__ = "users"
__table_args__ = (
Index(
"idx_users_oauth_identity",
"oauth_provider",
"oauth_id",
unique=True,
sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
),
{"comment": "User account table."},
)
id: Mapped[str] = mapped_column(String(36), primary_key=True)
email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
system_role: Mapped[str] = mapped_column(String(16), default="user")
password_hash: Mapped[str | None] = mapped_column(String(128), default=None)
oauth_provider: Mapped[str | None] = mapped_column(String(32), default=None)
oauth_id: Mapped[str | None] = mapped_column(String(128), default=None)
needs_setup: Mapped[bool] = mapped_column(Boolean, default=False)
token_version: Mapped[int] = mapped_column(default=0)
created_at: Mapped[datetime] = mapped_column(
TimeZone,
init=False,
default_factory=current_time,
sort_order=999,
comment="Created at",
)
@@ -0,0 +1,3 @@
from .timezone import get_timezone
__all__ = ["get_timezone"]
@@ -0,0 +1,51 @@
import zoneinfo
from datetime import UTC, datetime
from store.config.app_config import get_app_config
# IANA identifiers that map to UTC — see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones
_UTC_IDENTIFIERS = frozenset({"Etc/UCT", "Etc/Universal", "Etc/UTC", "Etc/Zulu", "UCT", "Universal", "UTC", "Zulu"})
class TimeZone:
def __init__(self) -> None:
app_config = get_app_config()
if app_config.timezone in _UTC_IDENTIFIERS:
self.tz_info = UTC
else:
self.tz_info = zoneinfo.ZoneInfo(app_config.timezone)
def now(self) -> datetime:
"""Return the current time in the configured timezone."""
return datetime.now(self.tz_info)
def from_datetime(self, t: datetime) -> datetime:
"""Convert a datetime to the configured timezone."""
return t.astimezone(self.tz_info)
def from_str(self, t_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> datetime:
"""Parse a time string and attach the configured timezone."""
return datetime.strptime(t_str, format_str).replace(tzinfo=self.tz_info)
@staticmethod
def to_str(t: datetime, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
"""Format a datetime to string."""
return t.strftime(format_str)
@staticmethod
def to_utc(t: datetime | int) -> datetime:
"""Convert a datetime or Unix timestamp to UTC."""
if isinstance(t, datetime):
return t.astimezone(UTC)
return datetime.fromtimestamp(t, tz=UTC)
_timezone = None
def get_timezone() -> TimeZone:
"""Return the global TimeZone singleton (lazy-initialized)."""
global _timezone
if _timezone is None:
_timezone = TimeZone()
return _timezone
+5 -3
View File
@@ -6,6 +6,7 @@ readme = "README.md"
requires-python = ">=3.12"
dependencies = [
"deerflow-harness",
"deerflow-storage",
"fastapi>=0.115.0",
"httpx>=0.28.0",
"python-multipart>=0.0.27",
@@ -24,8 +25,8 @@ dependencies = [
]
[project.optional-dependencies]
postgres = ["deerflow-harness[postgres]"]
discord = ["discord.py>=2.7.0"]
postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
mysql = ["deerflow-storage[mysql]"]
[dependency-groups]
dev = [
@@ -44,7 +45,8 @@ markers = [
index-url = "https://pypi.org/simple"
[tool.uv.workspace]
members = ["packages/harness"]
members = ["packages/harness", "packages/storage"]
[tool.uv.sources]
deerflow-harness = { workspace = true }
deerflow-storage = { workspace = true }
-15
View File
@@ -4,7 +4,6 @@ from pathlib import Path
import pytest
from _router_auth_helpers import call_unwrapped, make_authed_test_app
from fastapi import HTTPException
from fastapi.testclient import TestClient
from starlette.requests import Request
from starlette.responses import FileResponse
@@ -103,17 +102,3 @@ def test_get_artifact_download_true_forces_attachment_for_skill_archive(tmp_path
assert response.status_code == 200
assert response.text == "hello"
assert response.headers.get("content-disposition", "").startswith("attachment;")
def test_skill_archive_preview_rejects_oversized_member_before_decompression(tmp_path) -> None:
skill_path = tmp_path / "sample.skill"
payload = b"A" * (artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES + 1)
with zipfile.ZipFile(skill_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zip_ref:
zip_ref.writestr("SKILL.md", payload)
assert skill_path.stat().st_size < artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES
with pytest.raises(HTTPException) as exc_info:
artifacts_router._extract_file_from_skill_archive(skill_path, "SKILL.md")
assert exc_info.value.status_code == 413
+11 -47
View File
@@ -5,26 +5,28 @@ from unittest.mock import patch
import pytest
import app.gateway.auth.config as cfg
from app.gateway.auth.config import AuthConfig
def test_auth_config_defaults():
config = cfg.AuthConfig(jwt_secret="test-secret-key-123")
config = AuthConfig(jwt_secret="test-secret-key-123")
assert config.token_expiry_days == 7
def test_auth_config_token_expiry_range():
cfg.AuthConfig(jwt_secret="s", token_expiry_days=1)
cfg.AuthConfig(jwt_secret="s", token_expiry_days=30)
AuthConfig(jwt_secret="s", token_expiry_days=1)
AuthConfig(jwt_secret="s", token_expiry_days=30)
with pytest.raises(Exception):
cfg.AuthConfig(jwt_secret="s", token_expiry_days=0)
AuthConfig(jwt_secret="s", token_expiry_days=0)
with pytest.raises(Exception):
cfg.AuthConfig(jwt_secret="s", token_expiry_days=31)
AuthConfig(jwt_secret="s", token_expiry_days=31)
def test_auth_config_from_env():
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
with patch.dict(os.environ, env, clear=False):
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
try:
@@ -34,57 +36,19 @@ def test_auth_config_from_env():
cfg._auth_config = old
def test_auth_config_missing_secret_generates_and_persists(tmp_path, caplog):
def test_auth_config_missing_secret_generates_ephemeral(caplog):
import logging
from deerflow.config.paths import Paths
import app.gateway.auth.config as cfg
old = cfg._auth_config
cfg._auth_config = None
secret_file = tmp_path / ".jwt_secret"
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)), caplog.at_level(logging.WARNING):
with caplog.at_level(logging.WARNING):
config = cfg.get_auth_config()
assert config.jwt_secret
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
assert secret_file.exists()
assert secret_file.read_text().strip() == config.jwt_secret
finally:
cfg._auth_config = old
def test_auth_config_reuses_persisted_secret(tmp_path):
from deerflow.config.paths import Paths
old = cfg._auth_config
cfg._auth_config = None
persisted = "persisted-secret-from-file-min-32-chars!!"
(tmp_path / ".jwt_secret").write_text(persisted, encoding="utf-8")
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)):
config = cfg.get_auth_config()
assert config.jwt_secret == persisted
finally:
cfg._auth_config = old
def test_auth_config_empty_secret_file_generates_new(tmp_path):
from deerflow.config.paths import Paths
old = cfg._auth_config
cfg._auth_config = None
(tmp_path / ".jwt_secret").write_text("", encoding="utf-8")
try:
with patch.dict(os.environ, {}, clear=True):
os.environ.pop("AUTH_JWT_SECRET", None)
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)):
config = cfg.get_auth_config()
assert config.jwt_secret
assert len(config.jwt_secret) > 20
assert (tmp_path / ".jwt_secret").read_text().strip() == config.jwt_secret
finally:
cfg._auth_config = old
+1 -1
View File
@@ -761,7 +761,7 @@ class TestChannelManager:
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
async def _runs_wait(thread_id, assistant_id, *, input, config, context, multitask_strategy=None):
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
del assistant_id, context # unused in this test, kept for signature parity
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
+5 -2
View File
@@ -94,12 +94,15 @@ class TestHarnessPackaging:
"psycopg-pool>=3.3.0",
]
def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
def test_workspace_pyproject_forwards_postgres_extra_to_storage_packages(self):
pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
data = tomllib.loads(pyproject_path.read_text())
optional_dependencies = data["project"]["optional-dependencies"]
assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
assert optional_dependencies["postgres"] == [
"deerflow-harness[postgres]",
"deerflow-storage[postgres]",
]
def test_postgres_missing_dependency_messages_recommend_package_extra(self):
assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
@@ -158,88 +158,6 @@ class TestBuildPatchedMessagesPatching:
assert patched[1].name == "bash"
assert patched[1].status == "error"
def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self):
middleware = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
HumanMessage(content="interruption"),
_tool_msg("call_1", "bash"),
]
patched = middleware._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert isinstance(patched[2], HumanMessage)
def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
HumanMessage(content="interruption"),
_tool_msg("call_2", "read"),
_tool_msg("call_1", "bash"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert isinstance(patched[2], ToolMessage)
assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"]
assert isinstance(patched[3], HumanMessage)
def test_valid_adjacent_tool_results_are_unchanged(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
_tool_msg("call_1", "bash"),
HumanMessage(content="next"),
]
assert mw._build_patched_messages(msgs) is None
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
mw = DanglingToolCallMiddleware()
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
HumanMessage(content="interruption"),
_ai_with_tool_calls([_tc("read", "call_2")]),
_tool_msg("call_1", "bash"),
_tool_msg("call_2", "read"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert isinstance(patched[2], HumanMessage)
assert isinstance(patched[3], AIMessage)
assert isinstance(patched[4], ToolMessage)
assert patched[4].tool_call_id == "call_2"
def test_orphan_tool_message_is_preserved_during_grouping(self):
mw = DanglingToolCallMiddleware()
orphan = _tool_msg("orphan_call", "orphan")
msgs = [
_ai_with_tool_calls([_tc("bash", "call_1")]),
orphan,
HumanMessage(content="interruption"),
_tool_msg("call_1", "bash"),
]
patched = mw._build_patched_messages(msgs)
assert patched is not None
assert isinstance(patched[0], AIMessage)
assert isinstance(patched[1], ToolMessage)
assert patched[1].tool_call_id == "call_1"
assert orphan in patched
assert patched.count(orphan) == 1
def test_invalid_tool_call_is_patched(self):
mw = DanglingToolCallMiddleware()
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]
+1
View File
@@ -454,6 +454,7 @@ class TestAStream:
@pytest.mark.asyncio
async def test_with_tools_emits_tool_call_chunk(self):
tool_calls = [{"name": "fn", "args": {}, "id": "c1"}]
with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls)
+81
View File
@@ -0,0 +1,81 @@
from __future__ import annotations
import os
from pathlib import Path
import pytest
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects import mysql, postgresql
from sqlalchemy.types import JSON
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from store.persistence.json_compat import json_match
def _table():
metadata = MetaData()
return Table("t", metadata, Column("data", JSON), Column("id", String))
def test_storage_json_match_compiles_sqlite() -> None:
from sqlalchemy import create_engine
table = _table()
dialect = create_engine("sqlite://").dialect
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'null'")
assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'true'")
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "= 'integer'" in int_sql
assert "CAST" in int_sql
float_sql = str(json_match(table.c.data, "k", 3.14).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "IN ('integer', 'real')" in float_sql
assert "REAL" in float_sql
def test_storage_json_match_compiles_postgres() -> None:
table = _table()
dialect = postgresql.dialect()
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_typeof(t.data -> 'k') = 'null'")
assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')")
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "CASE WHEN" in int_sql
assert "BIGINT" in int_sql
assert "'^-?[0-9]+$'" in int_sql
def test_storage_json_match_compiles_mysql() -> None:
table = _table()
dialect = mysql.dialect()
null_sql = str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert null_sql == "JSON_TYPE(JSON_EXTRACT(t.data, '$.\"k\"')) = 'NULL'"
bool_sql = str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "JSON_TYPE(JSON_EXTRACT" in bool_sql
assert "= 'BOOLEAN'" in bool_sql
assert "= 'true'" in bool_sql
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
assert "= 'INTEGER'" in int_sql
assert "SIGNED" in int_sql
def test_storage_json_match_rejects_unsafe_keys_and_values() -> None:
table = _table()
for bad_key in ["a.b", "bad;key", "with space", "", 42, None]:
with pytest.raises(ValueError, match="JsonMatch key must match"):
json_match(table.c.data, bad_key, "x") # type: ignore[arg-type]
for bad_value in [[], {}, object()]:
with pytest.raises(TypeError, match="JsonMatch value must be"):
json_match(table.c.data, "k", bad_value)
with pytest.raises(TypeError, match="out of signed 64-bit range"):
json_match(table.c.data, "k", 2**63)
@@ -0,0 +1,122 @@
from __future__ import annotations
import os
import subprocess
import sys
from pathlib import Path
from types import SimpleNamespace
import pytest
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from store.config.storage_config import StorageConfig
from store.persistence.factory import _create_database_url, storage_config_from_database_config
def test_database_sqlite_config_maps_to_storage_config(tmp_path):
database = SimpleNamespace(
backend="sqlite",
sqlite_dir=str(tmp_path),
echo_sql=True,
pool_size=9,
)
storage = storage_config_from_database_config(database)
assert storage == StorageConfig(
driver="sqlite",
sqlite_dir=str(tmp_path),
echo_sql=True,
pool_size=9,
)
assert storage.sqlite_storage_path == str(tmp_path / "deerflow.db")
def test_database_memory_config_is_not_a_storage_backend():
database = SimpleNamespace(backend="memory")
with pytest.raises(ValueError, match="Unsupported database backend"):
storage_config_from_database_config(database)
def test_database_postgres_config_preserves_url_and_pool_options():
database = SimpleNamespace(
backend="postgres",
postgres_url="postgresql://user:pass@db.example:5544/deerflow",
echo_sql=True,
pool_size=11,
)
storage = storage_config_from_database_config(database)
url = _create_database_url(storage)
assert storage.driver == "postgres"
assert storage.database_url == "postgresql://user:pass@db.example:5544/deerflow"
assert storage.username == "user"
assert storage.password == "pass"
assert storage.host == "db.example"
assert storage.port == 5544
assert storage.db_name == "deerflow"
assert storage.echo_sql is True
assert storage.pool_size == 11
assert url.drivername == "postgresql+asyncpg"
assert url.database == "deerflow"
def test_mysql_database_url_is_normalized_to_async_driver():
storage = StorageConfig(
driver="mysql",
database_url="mysql://user:pass@db.example:3306/deerflow",
)
url = _create_database_url(storage)
assert url.drivername == "mysql+aiomysql"
assert url.database == "deerflow"
def test_mysql_async_database_url_is_preserved():
storage = StorageConfig(
driver="mysql",
database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow",
)
url = _create_database_url(storage)
assert url.drivername == "mysql+asyncmy"
assert url.database == "deerflow"
def test_database_postgres_requires_url():
database = SimpleNamespace(backend="postgres", postgres_url="")
with pytest.raises(ValueError, match="database.postgres_url is required"):
storage_config_from_database_config(database)
def test_unsupported_database_backend_rejected():
database = SimpleNamespace(backend="oracle")
with pytest.raises(ValueError, match="Unsupported database backend"):
storage_config_from_database_config(database)
def test_storage_models_import_without_config_file(tmp_path):
env = os.environ.copy()
env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml")
result = subprocess.run(
[
sys.executable,
"-c",
"from store.persistence.base_model import UniversalText, id_key; from store.repositories.models import RunEvent; print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
],
check=False,
capture_output=True,
env=env,
text=True,
)
assert result.returncode == 0, result.stderr
assert "UniversalText run_events" in result.stdout
@@ -0,0 +1,58 @@
from __future__ import annotations
import asyncio
import os
from pathlib import Path
from types import SimpleNamespace
from uuid import uuid4
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from sqlalchemy import inspect
from store.persistence import create_persistence_from_database_config
from store.repositories import UserCreate, build_user_repository
def test_sqlite_persistence_from_database_config_creates_storage_tables(tmp_path):
async def run() -> None:
persistence = await create_persistence_from_database_config(
SimpleNamespace(
backend="sqlite",
sqlite_dir=str(tmp_path),
echo_sql=False,
pool_size=5,
)
)
assert persistence is not None
try:
await persistence.setup()
async with persistence.engine.connect() as conn:
tables = await conn.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names()))
assert {
"users",
"runs",
"run_events",
"threads_meta",
"feedback",
}.issubset(tables)
async with persistence.session_factory() as session:
repo = build_user_repository(session)
user = await repo.create_user(
UserCreate(
id=str(uuid4()),
email="storage-user@example.com",
password_hash="hash",
)
)
await session.commit()
async with persistence.session_factory() as session:
repo = build_user_repository(session)
assert await repo.get_user_by_id(user.id) == user
finally:
await persistence.aclose()
asyncio.run(run())
+395
View File
@@ -0,0 +1,395 @@
from __future__ import annotations
import os
from datetime import UTC, datetime, timedelta
from pathlib import Path
from types import SimpleNamespace
import pytest
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from store.persistence import create_persistence_from_database_config
from store.repositories import (
FeedbackCreate,
InvalidMetadataFilterError,
RunCreate,
RunEventCreate,
ThreadMetaCreate,
build_feedback_repository,
build_run_event_repository,
build_run_repository,
build_thread_meta_repository,
)
async def _make_persistence(tmp_path):
persistence = await create_persistence_from_database_config(
SimpleNamespace(
backend="sqlite",
sqlite_dir=str(tmp_path),
echo_sql=False,
pool_size=5,
)
)
await persistence.setup()
return persistence
@pytest.mark.anyio
async def test_storage_run_repository_filters_and_aggregates(tmp_path):
persistence = await _make_persistence(tmp_path)
old = datetime.now(UTC) - timedelta(hours=1)
newer = datetime.now(UTC)
try:
async with persistence.session_factory() as session:
repo = build_run_repository(session)
await repo.create_run(
RunCreate(
run_id="run-old",
thread_id="thread-1",
user_id="alice",
status="pending",
model_name="model-a",
metadata={"kind": "draft"},
kwargs={"temperature": 0.2},
created_time=old,
)
)
await repo.create_run(
RunCreate(
run_id="run-new",
thread_id="thread-1",
user_id="bob",
status="running",
model_name="model-b",
error="queued",
created_time=newer,
)
)
await repo.create_run(RunCreate(run_id="run-other", thread_id="thread-2", status="running"))
await repo.update_run_completion(
"run-old",
status="success",
total_input_tokens=7,
total_output_tokens=3,
total_tokens=10,
llm_call_count=1,
lead_agent_tokens=8,
subagent_tokens=2,
first_human_message="hello",
last_ai_message="world",
)
await repo.update_run_completion(
"run-new",
status="error",
total_tokens=5,
middleware_tokens=5,
error="failed",
)
await session.commit()
async with persistence.session_factory() as session:
repo = build_run_repository(session)
fetched = await repo.get_run("run-old")
assert fetched is not None
assert fetched.metadata == {"kind": "draft"}
assert fetched.kwargs == {"temperature": 0.2}
assert fetched.first_human_message == "hello"
assert fetched.last_ai_message == "world"
all_thread_runs = await repo.list_runs_by_thread("thread-1")
assert [run.run_id for run in all_thread_runs] == ["run-new", "run-old"]
alice_runs = await repo.list_runs_by_thread("thread-1", user_id="alice")
assert [run.run_id for run in alice_runs] == ["run-old"]
pending = await repo.list_pending(before=datetime.now(UTC).isoformat())
assert [run.run_id for run in pending] == []
agg = await repo.aggregate_tokens_by_thread("thread-1")
assert agg["total_tokens"] == 15
assert agg["total_input_tokens"] == 7
assert agg["total_output_tokens"] == 3
assert agg["total_runs"] == 2
assert agg["by_model"] == {
"model-a": {"tokens": 10, "runs": 1},
"model-b": {"tokens": 5, "runs": 1},
}
assert agg["by_caller"] == {"lead_agent": 8, "subagent": 2, "middleware": 5}
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
await repo.create_thread_meta(
ThreadMetaCreate(
thread_id="thread-1",
assistant_id="agent-a",
user_id="alice",
display_name="Initial",
status="idle",
metadata={"topic": "finance", "region": "cn"},
)
)
await repo.create_thread_meta(
ThreadMetaCreate(
thread_id="thread-2",
assistant_id="agent-b",
user_id="bob",
status="running",
metadata={"topic": "legal"},
)
)
await repo.update_thread_meta(
"thread-1",
display_name="Updated",
status="running",
metadata={"topic": "finance", "region": "us"},
)
await session.commit()
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
fetched = await repo.get_thread_meta("thread-1")
assert fetched is not None
assert fetched.display_name == "Updated"
assert fetched.status == "running"
assert fetched.metadata == {"topic": "finance", "region": "us"}
by_metadata = await repo.search_threads(metadata={"topic": "finance"}, user_id="alice")
assert [thread.thread_id for thread in by_metadata] == ["thread-1"]
by_assistant = await repo.search_threads(assistant_id="agent-b")
assert [thread.thread_id for thread in by_assistant] == ["thread-2"]
await repo.delete_thread("thread-1")
await session.commit()
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
assert await repo.get_thread_meta("thread-1") is None
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filters_are_type_safe(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-true", metadata={"value": True}))
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-false", metadata={"value": False}))
await repo.create_thread_meta(ThreadMetaCreate(thread_id="int-one", metadata={"value": 1}))
await repo.create_thread_meta(ThreadMetaCreate(thread_id="null-value", metadata={"value": None}))
await repo.create_thread_meta(ThreadMetaCreate(thread_id="missing-value", metadata={"other": "x"}))
await session.commit()
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
assert [row.thread_id for row in await repo.search_threads(metadata={"value": True})] == ["bool-true"]
assert [row.thread_id for row in await repo.search_threads(metadata={"value": False})] == ["bool-false"]
assert [row.thread_id for row in await repo.search_threads(metadata={"value": 1})] == ["int-one"]
assert [row.thread_id for row in await repo.search_threads(metadata={"value": None})] == ["null-value"]
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filters_paginate_after_sql_match(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
for index in range(30):
metadata = {"target": "yes"} if index % 3 == 0 else {"target": "no"}
await repo.create_thread_meta(ThreadMetaCreate(thread_id=f"thread-{index:02d}", metadata=metadata))
await session.commit()
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
first_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=0)
second_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=3)
last_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=9)
assert len(first_page) == 3
assert len(second_page) == 3
assert len(last_page) == 1
assert {row.thread_id for row in first_page}.isdisjoint({row.thread_id for row in second_page})
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_thread_meta_metadata_filter_rejects_invalid_entries(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-1", metadata={"env": "prod"}))
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-2", metadata={"env": "staging"}))
await session.commit()
async with persistence.session_factory() as session:
repo = build_thread_meta_repository(session)
partial = await repo.search_threads(metadata={"env": "prod", "bad;key": "ignored"})
assert [row.thread_id for row in partial] == ["thread-1"]
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
await repo.search_threads(metadata={"bad;key": "x"})
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
await repo.search_threads(metadata={"env": ["prod", "staging"]})
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_feedback_repository(session)
first = await repo.create_feedback(
FeedbackCreate(
feedback_id="fb-1",
run_id="run-1",
thread_id="thread-1",
rating=1,
user_id="alice",
message_id="msg-1",
comment="good",
)
)
second = await repo.create_feedback(
FeedbackCreate(
feedback_id="fb-2",
run_id="run-1",
thread_id="thread-1",
rating=-1,
user_id="bob",
)
)
await session.commit()
async with persistence.session_factory() as session:
repo = build_feedback_repository(session)
assert await repo.get_feedback(first.feedback_id) == first
assert [item.feedback_id for item in await repo.list_feedback_by_run("run-1")] == [
second.feedback_id,
first.feedback_id,
]
assert {item.feedback_id for item in await repo.list_feedback_by_thread("thread-1")} == {
"fb-1",
"fb-2",
}
assert await repo.delete_feedback("fb-1") is True
assert await repo.delete_feedback("missing") is False
with pytest.raises(ValueError, match="rating must be"):
await repo.create_feedback(
FeedbackCreate(
feedback_id="fb-bad",
run_id="run-1",
thread_id="thread-1",
rating=0,
)
)
await session.commit()
async with persistence.session_factory() as session:
repo = build_feedback_repository(session)
assert await repo.get_feedback("fb-1") is None
finally:
await persistence.aclose()
@pytest.mark.anyio
async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_path):
persistence = await _make_persistence(tmp_path)
try:
async with persistence.session_factory() as session:
repo = build_run_event_repository(session)
rows = await repo.append_batch(
[
RunEventCreate(
thread_id="thread-1",
run_id="run-1",
user_id="alice",
event_type="message",
category="message",
content={"role": "user", "content": "hello"},
metadata={"source": "input"},
),
RunEventCreate(
thread_id="thread-1",
run_id="run-1",
event_type="tool",
category="debug",
content="tool-call",
),
RunEventCreate(
thread_id="thread-1",
run_id="run-2",
event_type="message",
category="message",
content="second",
),
RunEventCreate(
thread_id="thread-2",
run_id="run-3",
event_type="message",
category="message",
content="other-thread",
),
]
)
await session.commit()
assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
assert [row.seq for row in rows] == sorted(row.seq for row in rows)
assert rows[1].seq == rows[0].seq + 1
assert rows[2].seq == rows[1].seq + 1
assert rows[0].content == {"role": "user", "content": "hello"}
assert rows[0].metadata == {"source": "input", "content_is_json": True}
async with persistence.session_factory() as session:
repo = build_run_event_repository(session)
messages = await repo.list_messages("thread-1", limit=2)
assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
assert await repo.count_messages("thread-1") == 2
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
assert [event.seq for event in after] == [rows[0].seq]
before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
assert [event.seq for event in before] == [rows[0].seq]
events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
assert [event.content for event in events] == ["tool-call"]
assert await repo.delete_by_run("thread-1", "run-1") == 2
assert await repo.delete_by_thread("thread-2") == 1
await session.commit()
async with persistence.session_factory() as session:
repo = build_run_event_repository(session)
remaining = await repo.list_events("thread-1", "run-2")
assert [event.seq for event in remaining] == [rows[2].seq]
assert await repo.count_messages("thread-2") == 0
later = await repo.append_batch(
[
RunEventCreate(
thread_id="thread-1",
run_id="run-4",
event_type="message",
category="message",
content="after-delete",
)
]
)
assert later[0].seq > rows[2].seq
finally:
await persistence.aclose()
@@ -0,0 +1,177 @@
from __future__ import annotations
import asyncio
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from uuid import uuid4
import pytest
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from store.repositories import UserCreate, UserNotFoundError, build_user_repository
from store.repositories.models import User as UserModel
@asynccontextmanager
async def _session_factory(tmp_path) -> AsyncGenerator[async_sessionmaker[AsyncSession]]:
db_path = tmp_path / "storage-users.db"
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
async with engine.begin() as conn:
await conn.run_sync(UserModel.metadata.create_all)
try:
yield async_sessionmaker(engine, expire_on_commit=False)
finally:
await engine.dispose()
async def _create_user(
session_factory: async_sessionmaker[AsyncSession],
*,
email: str = "user@example.com",
system_role: str = "user",
oauth_provider: str | None = None,
oauth_id: str | None = None,
):
async with session_factory() as session:
repo = build_user_repository(session)
user = await repo.create_user(
UserCreate(
id=str(uuid4()),
email=email,
password_hash="hash",
system_role=system_role, # type: ignore[arg-type]
oauth_provider=oauth_provider,
oauth_id=oauth_id,
)
)
await session.commit()
return user
def test_create_and_get_user_by_id_and_email(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
created = await _create_user(session_factory)
async with session_factory() as session:
repo = build_user_repository(session)
by_id = await repo.get_user_by_id(created.id)
by_email = await repo.get_user_by_email(created.email)
assert by_id == created
assert by_email == created
assert created.system_role == "user"
assert created.needs_setup is False
assert created.token_version == 0
asyncio.run(run())
def test_duplicate_email_raises_value_error(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
await _create_user(session_factory, email="dupe@example.com")
async with session_factory() as session:
repo = build_user_repository(session)
with pytest.raises(ValueError, match="Email already registered"):
await repo.create_user(
UserCreate(
id=str(uuid4()),
email="dupe@example.com",
password_hash="hash",
)
)
asyncio.run(run())
def test_oauth_lookup_and_plain_users_without_oauth(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
await _create_user(session_factory, email="local-1@example.com")
await _create_user(session_factory, email="local-2@example.com")
oauth_user = await _create_user(
session_factory,
email="oauth@example.com",
oauth_provider="github",
oauth_id="gh-123",
)
async with session_factory() as session:
repo = build_user_repository(session)
assert await repo.count_users() == 3
assert await repo.get_user_by_oauth("github", "gh-123") == oauth_user
assert await repo.get_user_by_oauth("github", "missing") is None
asyncio.run(run())
def test_count_admins_and_get_first_admin(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
await _create_user(session_factory, email="user@example.com")
admin = await _create_user(
session_factory,
email="admin@example.com",
system_role="admin",
)
async with session_factory() as session:
repo = build_user_repository(session)
assert await repo.count_users() == 2
assert await repo.count_admin_users() == 1
assert await repo.get_first_admin() == admin
asyncio.run(run())
def test_update_user_round_trips_token_version_and_setup_state(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
created = await _create_user(session_factory)
updated = created.model_copy(
update={
"email": "renamed@example.com",
"token_version": 4,
"needs_setup": True,
}
)
async with session_factory() as session:
repo = build_user_repository(session)
saved = await repo.update_user(updated)
await session.commit()
async with session_factory() as session:
repo = build_user_repository(session)
fetched = await repo.get_user_by_id(created.id)
assert saved.email == "renamed@example.com"
assert fetched == updated
asyncio.run(run())
def test_update_missing_user_raises(tmp_path):
async def run() -> None:
async with _session_factory(tmp_path) as session_factory:
missing = UserCreate(id=str(uuid4()), email="missing@example.com")
async with session_factory() as session:
repo = build_user_repository(session)
created_shape = await repo.create_user(missing)
await session.rollback()
with pytest.raises(UserNotFoundError):
await repo.update_user(created_shape)
asyncio.run(run())
+1 -79
View File
@@ -56,8 +56,7 @@ def _middleware(
preserve_recent_skill_tokens_per_skill: int = 0,
) -> DeerFlowSummarizationMiddleware:
model = MagicMock()
model.invoke.return_value = AIMessage(content="compressed summary")
model.with_config.return_value.invoke.return_value = AIMessage(content="compressed summary")
model.invoke.return_value = SimpleNamespace(text="compressed summary")
return DeerFlowSummarizationMiddleware(
model=model,
trigger=trigger,
@@ -643,69 +642,6 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
# ---------------------------------------------------------------------------
# Issue #2804: summary text must not leak to the frontend via streaming
# ---------------------------------------------------------------------------
def test_build_new_messages_sets_hide_from_ui() -> None:
"""The summary HumanMessage must carry hide_from_ui so the frontend filters it."""
middleware = _middleware()
messages = middleware._build_new_messages("test summary")
assert len(messages) == 1
msg = messages[0]
assert msg.name == "summary"
assert msg.additional_kwargs.get("hide_from_ui") is True
assert "test summary" in msg.content
def test_create_summary_suppresses_callbacks() -> None:
"""_create_summary must bind callbacks=[] on the model AND pass callbacks=[]
in the invoke config to suppress inherited LangGraph stream callbacks."""
middleware = _middleware()
middleware._create_summary(_messages())
middleware.model.with_config.assert_called_once_with(callbacks=[])
bound = middleware.model.with_config.return_value
bound.invoke.assert_called_once()
call_config = bound.invoke.call_args.kwargs.get("config") or bound.invoke.call_args[1].get("config")
assert call_config is not None
assert call_config.get("callbacks") == []
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
@pytest.mark.anyio
async def test_acreate_summary_suppresses_callbacks() -> None:
"""_acreate_summary must bind callbacks=[] on the model AND pass callbacks=[]
in the ainvoke config to suppress inherited LangGraph stream callbacks."""
middleware = _middleware()
middleware.model.with_config.return_value.ainvoke = mock.AsyncMock(return_value=AIMessage(content="async summary"))
await middleware._acreate_summary(_messages())
middleware.model.with_config.assert_called_once_with(callbacks=[])
bound = middleware.model.with_config.return_value
bound.ainvoke.assert_called_once()
call_config = bound.ainvoke.call_args.kwargs.get("config") or bound.ainvoke.call_args[1].get("config")
assert call_config is not None
assert call_config.get("callbacks") == []
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
def test_before_model_summary_message_has_hide_from_ui() -> None:
"""End-to-end: the emitted state update contains a summary message with hide_from_ui."""
middleware = _middleware()
result = middleware.before_model({"messages": _messages()}, _runtime())
emitted = result["messages"]
summary_msg = emitted[1]
assert summary_msg.name == "summary"
assert summary_msg.additional_kwargs.get("hide_from_ui") is True
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
queue = MagicMock()
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
@@ -723,17 +659,3 @@ def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatc
queue.add_nowait.assert_called_once()
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
def test_extract_summary_text_normalizes_list_content_blocks() -> None:
"""AIMessage.content can be a list of content blocks; _extract_summary_text
must normalize to plain text via the .text property instead of producing
a Python repr like [{'type': 'text', 'text': 'summary'}]."""
middleware = _middleware()
response = AIMessage(content=[{"type": "text", "text": "A summary of the chat."}])
assert middleware._extract_summary_text(response) == "A summary of the chat."
# Plain string content still works
response_str = AIMessage(content="Plain summary")
assert middleware._extract_summary_text(response_str) == "Plain summary"
+3 -2
View File
@@ -1214,11 +1214,12 @@ def test_terminal_event_usage_none_when_no_records(monkeypatch):
assert completed[0]["usage"] is None
def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
@pytest.mark.parametrize("error", [FileNotFoundError("missing config"), ValueError("invalid config")])
def test_subagent_usage_cache_is_skipped_when_default_config_cannot_load(monkeypatch, error):
monkeypatch.setattr(
task_tool_module,
"get_app_config",
MagicMock(side_effect=FileNotFoundError("missing config")),
MagicMock(side_effect=error),
)
assert task_tool_module._token_usage_cache_enabled(None) is False
@@ -93,7 +93,7 @@ class TestTitleMiddlewareCoreLogic:
assert middleware._should_generate_title(state) is False
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
_set_test_title_config(max_chars=12, model_name=None)
_set_test_title_config(max_chars=12)
middleware = TitleMiddleware()
model = MagicMock()
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
+15 -359
View File
@@ -1,19 +1,14 @@
"""Tests for TodoMiddleware context-loss detection."""
import asyncio
from typing import Any
from unittest.mock import AsyncMock, MagicMock
from unittest.mock import MagicMock
from langchain.agents import create_agent
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
from langchain_core.messages import AIMessage, HumanMessage
from pydantic import PrivateAttr
from deerflow.agents.middlewares.todo_middleware import (
TodoMiddleware,
_completion_reminder_count,
_format_todos,
_has_tool_call_intent_or_error,
_reminder_in_messages,
_todos_in_messages,
)
@@ -27,35 +22,9 @@ def _reminder_msg():
return HumanMessage(name="todo_reminder", content="reminder")
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
_seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)
@property
def seen_messages(self) -> list[list[Any]]:
return self._seen_messages
def bind_tools(self, tools, *, tool_choice=None, **kwargs):
return self
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
self._seen_messages.append(list(messages))
return super()._generate(
messages,
stop=stop,
run_manager=run_manager,
**kwargs,
)
def _make_runtime():
runtime = MagicMock()
runtime.context = {"thread_id": "test-thread", "run_id": "test-run"}
return runtime
def _make_runtime_for(thread_id: str, run_id: str):
runtime = _make_runtime()
runtime.context = {"thread_id": thread_id, "run_id": run_id}
runtime.context = {"thread_id": "test-thread"}
return runtime
@@ -192,62 +161,10 @@ def _completion_reminder_msg():
return HumanMessage(name="todo_completion_reminder", content="finish your todos")
def _todo_completion_reminders(messages):
reminders = []
for message in messages:
if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder":
reminders.append(message)
return reminders
def _ai_no_tool_calls():
return AIMessage(content="I'm done!")
def _ai_with_invalid_tool_calls():
return AIMessage(
content="",
tool_calls=[],
invalid_tool_calls=[
{
"type": "invalid_tool_call",
"id": "write_file:36",
"name": "write_file",
"args": "{invalid",
"error": "Failed to parse tool arguments",
}
],
)
def _ai_with_raw_provider_tool_calls():
return AIMessage(
content="",
tool_calls=[],
invalid_tool_calls=[],
additional_kwargs={
"tool_calls": [
{
"id": "raw-tool-call",
"type": "function",
"function": {"name": "write_file", "arguments": '{"path":"report.md"}'},
}
]
},
)
def _ai_with_legacy_function_call():
return AIMessage(
content="",
additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}},
)
def _ai_with_tool_finish_reason():
return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"})
def _incomplete_todos():
return [
{"status": "completed", "content": "Step 1"},
@@ -277,36 +194,6 @@ class TestCompletionReminderCount:
assert _completion_reminder_count(msgs) == 1
class TestToolCallIntentOrError:
def test_false_for_plain_final_answer(self):
assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False
def test_true_for_structured_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True
def test_true_for_invalid_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True
def test_true_for_raw_provider_tool_calls(self):
assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True
def test_true_for_legacy_function_call(self):
assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True
def test_true_for_tool_finish_reason(self):
assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True
def test_langchain_ai_message_tool_fields_are_explicitly_handled(self):
# Sentinel for LangChain compatibility: if future AIMessage versions add
# new top-level tool/function-call fields, this test should fail. When
# it does, update `_has_tool_call_intent_or_error()` so the completion
# reminder guard explicitly decides whether each new field means "not a
# clean final answer"; the helper has a matching comment pointing back
# to this sentinel.
tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())}
assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"}
class TestAfterModel:
def test_returns_none_when_agent_still_using_tools(self):
mw = TodoMiddleware()
@@ -348,299 +235,68 @@ class TestAfterModel:
}
assert mw.after_model(state, _make_runtime()) is None
def test_queues_reminder_and_jumps_to_model_when_incomplete(self):
def test_injects_reminder_and_jumps_to_model_when_incomplete(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
result = mw.after_model(state, runtime)
result = mw.after_model(state, _make_runtime())
assert result is not None
assert result["jump_to"] == "model"
assert "messages" not in result
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
request.override.assert_called_once()
reminder = request.override.call_args.kwargs["messages"][-1]
assert len(result["messages"]) == 1
reminder = result["messages"][0]
assert isinstance(reminder, HumanMessage)
assert reminder.name == "todo_completion_reminder"
assert reminder.additional_kwargs["hide_from_ui"] is True
assert "Step 2" in reminder.content
assert "Step 3" in reminder.content
handler.assert_called_once_with("patched-request")
def test_reminder_lists_only_incomplete_items(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
result = mw.after_model(state, runtime)
assert result is not None
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
mw.wrap_model_call(request, MagicMock(return_value="response"))
content = request.override.call_args.kwargs["messages"][-1].content
result = mw.after_model(state, _make_runtime())
content = result["messages"][0].content
assert "Step 1" not in content # completed — should not appear
assert "Step 2" in content
assert "Step 3" in content
def test_allows_exit_after_max_reminders(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [
_completion_reminder_msg(),
_completion_reminder_msg(),
_ai_no_tool_calls(),
],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, runtime) is not None
assert mw.after_model(state, runtime) is not None
assert mw.after_model(state, runtime) is None
assert mw.after_model(state, _make_runtime()) is None
def test_still_sends_reminder_before_cap(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [
_completion_reminder_msg(), # 1 reminder so far
_ai_no_tool_calls(),
],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, runtime) is not None
result = mw.after_model(state, runtime)
result = mw.after_model(state, _make_runtime())
assert result is not None
assert result["jump_to"] == "model"
def test_does_not_trigger_for_invalid_tool_calls(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_invalid_tool_calls()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
def test_does_not_trigger_for_raw_provider_tool_calls(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_raw_provider_tool_calls()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
def test_does_not_trigger_for_legacy_function_call(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_legacy_function_call()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
def test_does_not_trigger_for_tool_finish_reason(self):
mw = TodoMiddleware()
state = {
"messages": [_ai_with_tool_finish_reason()],
"todos": _incomplete_todos(),
}
assert mw.after_model(state, _make_runtime()) is None
class TestAafterModel:
def test_delegates_to_sync(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
result = asyncio.run(mw.aafter_model(state, runtime))
result = asyncio.run(mw.aafter_model(state, _make_runtime()))
assert result is not None
assert result["jump_to"] == "model"
assert "messages" not in result
class TestWrapModelCall:
def test_no_pending_reminder_passthrough(self):
mw = TodoMiddleware()
request = MagicMock()
request.runtime = _make_runtime()
request.messages = [HumanMessage(content="hi")]
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
request.override.assert_not_called()
handler.assert_called_once_with(request)
def test_pending_reminder_is_injected_once(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
mw.after_model(state, runtime)
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = MagicMock(return_value="response")
assert mw.wrap_model_call(request, handler) == "response"
injected_messages = request.override.call_args.kwargs["messages"]
assert injected_messages[-1].name == "todo_completion_reminder"
request.override.reset_mock()
handler.reset_mock()
handler.return_value = "second-response"
assert mw.wrap_model_call(request, handler) == "second-response"
request.override.assert_not_called()
handler.assert_called_once_with(request)
class TestTodoMiddlewareAgentGraphIntegration:
def test_completion_reminder_is_transient_in_real_agent_graph(self):
mw = TodoMiddleware()
model = _CapturingFakeMessagesListChatModel(
responses=[
AIMessage(
content="",
tool_calls=[
{
"name": "write_todos",
"id": "todos-1",
"args": {
"todos": [
{"content": "Step 1", "status": "completed"},
{"content": "Step 2", "status": "pending"},
]
},
}
],
),
AIMessage(content="premature final 1"),
AIMessage(content="premature final 2"),
AIMessage(content="premature final 3"),
],
)
graph = create_agent(model=model, tools=[], middleware=[mw])
result = graph.invoke(
{"messages": [("user", "finish all todos")]},
context={"thread_id": "integration-thread", "run_id": "integration-run"},
)
assert len(model.seen_messages) == 4
reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages]
assert reminders_by_call[0] == []
assert reminders_by_call[1] == []
assert len(reminders_by_call[2]) == 1
assert len(reminders_by_call[3]) == 1
assert "Step 1" not in reminders_by_call[2][0].content
assert "Step 2" in reminders_by_call[2][0].content
persisted_reminders = _todo_completion_reminders(result["messages"])
assert persisted_reminders == []
assert result["messages"][-1].content == "premature final 3"
assert result["todos"] == [
{"content": "Step 1", "status": "completed"},
{"content": "Step 2", "status": "pending"},
]
assert mw._pending_completion_reminders == {}
assert mw._completion_reminder_counts == {}
class TestRunScopedReminderCleanup:
def test_before_agent_clears_stale_count_without_pending_reminder(self):
mw = TodoMiddleware()
stale_runtime = _make_runtime()
stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"}
current_runtime = _make_runtime()
current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"}
other_thread_runtime = _make_runtime()
other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"}
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, stale_runtime) is not None
assert mw.after_model(state, other_thread_runtime) is not None
# Simulate a model call that drained the pending message, followed by an
# abnormal run end where after_agent did not clear the reminder count.
assert mw._drain_completion_reminders(stale_runtime)
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1
mw.before_agent({}, current_runtime)
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1
def test_size_guard_prunes_oldest_count_only_reminder_state(self):
mw = TodoMiddleware()
mw._MAX_COMPLETION_REMINDER_KEYS = 2
first_runtime = _make_runtime_for("thread-a", "run-a")
second_runtime = _make_runtime_for("thread-b", "run-b")
third_runtime = _make_runtime_for("thread-c", "run-c")
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, first_runtime) is not None
# Simulate the normal model request path: pending reminder is consumed,
# but the run count remains until after_agent() or stale cleanup.
assert mw._drain_completion_reminders(first_runtime)
assert mw._completion_reminder_count_for_runtime(first_runtime) == 1
assert mw.after_model(state, second_runtime) is not None
assert mw.after_model(state, third_runtime) is not None
assert mw._completion_reminder_count_for_runtime(first_runtime) == 0
assert mw._completion_reminder_count_for_runtime(second_runtime) == 1
assert mw._completion_reminder_count_for_runtime(third_runtime) == 1
assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order
def test_size_guard_prunes_pending_and_count_state_together(self):
mw = TodoMiddleware()
mw._MAX_COMPLETION_REMINDER_KEYS = 1
stale_runtime = _make_runtime_for("thread-a", "run-a")
current_runtime = _make_runtime_for("thread-b", "run-b")
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
assert mw.after_model(state, stale_runtime) is not None
assert mw.after_model(state, current_runtime) is not None
assert mw._drain_completion_reminders(stale_runtime) == []
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
assert mw._completion_reminder_count_for_runtime(current_runtime) == 1
class TestAwrapModelCall:
def test_async_pending_reminder_is_injected(self):
mw = TodoMiddleware()
runtime = _make_runtime()
state = {
"messages": [_ai_no_tool_calls()],
"todos": _incomplete_todos(),
}
mw.after_model(state, runtime)
request = MagicMock()
request.runtime = runtime
request.messages = state["messages"]
request.override.return_value = "patched-request"
handler = AsyncMock(return_value="response")
result = asyncio.run(mw.awrap_model_call(request, handler))
assert result == "response"
injected_messages = request.override.call_args.kwargs["messages"]
assert injected_messages[-1].name == "todo_completion_reminder"
handler.assert_awaited_once_with("patched-request")
assert result["messages"][0].name == "todo_completion_reminder"
+93 -18
View File
@@ -1,5 +1,5 @@
version = 1
revision = 2
revision = 3
requires-python = ">=3.12"
resolution-markers = [
"python_full_version >= '3.14' and sys_platform == 'win32'",
@@ -14,6 +14,7 @@ resolution-markers = [
members = [
"deer-flow",
"deerflow-harness",
"deerflow-storage",
]
[[package]]
@@ -136,6 +137,18 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441, upload-time = "2026-03-31T22:00:12.791Z" },
]
[[package]]
name = "aiomysql"
version = "0.3.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pymysql" },
]
sdist = { url = "https://files.pythonhosted.org/packages/29/e0/302aeffe8d90853556f47f3106b89c16cc2ec2a4d269bdfd82e3f4ae12cc/aiomysql-0.3.2.tar.gz", hash = "sha256:72d15ef5cfc34c03468eb41e1b90adb9fd9347b0b589114bd23ead569a02ac1a", size = 108311, upload-time = "2025-10-22T00:15:21.278Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/4c/af/aae0153c3e28712adaf462328f6c7a3c196a1c1c27b491de4377dd3e6b52/aiomysql-0.3.2-py3-none-any.whl", hash = "sha256:c82c5ba04137d7afd5c693a258bea8ead2aad77101668044143a991e04632eb2", size = 71834, upload-time = "2025-10-22T00:15:15.905Z" },
]
[[package]]
name = "aiosignal"
version = "1.4.0"
@@ -746,6 +759,7 @@ source = { virtual = "." }
dependencies = [
{ name = "bcrypt" },
{ name = "deerflow-harness" },
{ name = "deerflow-storage" },
{ name = "dingtalk-stream" },
{ name = "email-validator" },
{ name = "fastapi" },
@@ -763,11 +777,12 @@ dependencies = [
]
[package.optional-dependencies]
discord = [
{ name = "discord-py" },
mysql = [
{ name = "deerflow-storage", extra = ["mysql"] },
]
postgres = [
{ name = "deerflow-harness", extra = ["postgres"] },
{ name = "deerflow-storage", extra = ["postgres"] },
]
[package.dev-dependencies]
@@ -783,8 +798,10 @@ requires-dist = [
{ name = "bcrypt", specifier = ">=4.0.0" },
{ name = "deerflow-harness", editable = "packages/harness" },
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
{ name = "deerflow-storage", editable = "packages/storage" },
{ name = "deerflow-storage", extras = ["mysql"], marker = "extra == 'mysql'", editable = "packages/storage" },
{ name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
{ name = "discord-py", marker = "extra == 'discord'", specifier = ">=2.7.0" },
{ name = "email-validator", specifier = ">=2.0.0" },
{ name = "fastapi", specifier = ">=0.115.0" },
{ name = "httpx", specifier = ">=0.28.0" },
@@ -799,7 +816,7 @@ requires-dist = [
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
]
provides-extras = ["postgres", "discord"]
provides-extras = ["postgres", "mysql"]
[package.metadata.requires-dev]
dev = [
@@ -905,6 +922,54 @@ requires-dist = [
]
provides-extras = ["ollama", "postgres", "pymupdf"]
[[package]]
name = "deerflow-storage"
version = "0.1.0"
source = { editable = "packages/storage" }
dependencies = [
{ name = "alembic" },
{ name = "dotenv" },
{ name = "langgraph" },
{ name = "pydantic" },
{ name = "pyyaml" },
{ name = "sqlalchemy", extra = ["asyncio"] },
]
[package.optional-dependencies]
mysql = [
{ name = "aiomysql" },
{ name = "langgraph-checkpoint-mysql" },
]
postgres = [
{ name = "asyncpg" },
{ name = "langgraph-checkpoint-postgres" },
{ name = "psycopg", extra = ["binary"] },
{ name = "psycopg-pool" },
]
sqlite = [
{ name = "aiosqlite" },
{ name = "langgraph-checkpoint-sqlite" },
]
[package.metadata]
requires-dist = [
{ name = "aiomysql", marker = "extra == 'mysql'", specifier = ">=0.2" },
{ name = "aiosqlite", marker = "extra == 'sqlite'", specifier = ">=0.22.1" },
{ name = "alembic", specifier = ">=1.13" },
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
{ name = "dotenv", specifier = ">=0.9.9" },
{ name = "langgraph", specifier = ">=1.1.9" },
{ name = "langgraph-checkpoint-mysql", marker = "extra == 'mysql'", specifier = ">=3.0.0" },
{ name = "langgraph-checkpoint-postgres", marker = "extra == 'postgres'", specifier = ">=3.0.5" },
{ name = "langgraph-checkpoint-sqlite", marker = "extra == 'sqlite'", specifier = ">=3.0.3" },
{ name = "psycopg", extras = ["binary"], marker = "extra == 'postgres'", specifier = ">=3.3.3" },
{ name = "psycopg-pool", marker = "extra == 'postgres'", specifier = ">=3.3.0" },
{ name = "pydantic", specifier = ">=2.12.5" },
{ name = "pyyaml", specifier = ">=6.0.3" },
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0,<3.0" },
]
provides-extras = ["postgres", "mysql", "sqlite"]
[[package]]
name = "defusedxml"
version = "0.7.1"
@@ -927,19 +992,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/4c/44/102dede3f371277598df6aa9725b82e3add068c729333c7a5dbc12764579/dingtalk_stream-0.24.3-py3-none-any.whl", hash = "sha256:2160403656985962878bf60cdf5adf41619f21067348e06f07a7c7eebf5943ad", size = 27813, upload-time = "2025-10-24T09:36:57.497Z" },
]
[[package]]
name = "discord-py"
version = "2.7.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "aiohttp" },
{ name = "audioop-lts", marker = "python_full_version >= '3.13'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ef/57/9a2d9abdabdc9db8ef28ce0cf4129669e1c8717ba28d607b5ba357c4de3b/discord_py-2.7.1.tar.gz", hash = "sha256:24d5e6a45535152e4b98148a9dd6b550d25dc2c9fb41b6d670319411641249da", size = 1106326, upload-time = "2026-03-03T18:40:46.24Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/f7/a7/17208c3b3f92319e7fad259f1c6d5a5baf8fd0654c54846ced329f83c3eb/discord_py-2.7.1-py3-none-any.whl", hash = "sha256:849dca2c63b171146f3a7f3f8acc04248098e9e6203412ce3cf2745f284f7439", size = 1227550, upload-time = "2026-03-03T18:40:44.492Z" },
]
[[package]]
name = "distro"
version = "1.9.0"
@@ -1931,6 +1983,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/b9/5a/6dba29dd89b0a46ae21c707da0f9d17e94f27d3e481ed15bc99d6bd20aa6/langgraph_checkpoint-4.0.2-py3-none-any.whl", hash = "sha256:59b0f29216128a629c58dd07c98aa004f82f51805d5573126ffb419b753ff253", size = 51000, upload-time = "2026-04-15T21:02:59.096Z" },
]
[[package]]
name = "langgraph-checkpoint-mysql"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "langgraph-checkpoint" },
{ name = "orjson" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/e0/4e/0a6c78e5d3f2ca1525903c2363e721873594b6b77dd83537a6369193c474/langgraph_checkpoint_mysql-3.0.0.tar.gz", hash = "sha256:006aaa089f4c2fbd7b2c113b800ccd3dbb95f92203e656451677256b4b4f880f", size = 213142, upload-time = "2026-01-23T11:11:15.74Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/08/68/343103a7fae05523f9cecabbec2babdb737e66b4bf6ea48ae00c685ed11c/langgraph_checkpoint_mysql-3.0.0-py3-none-any.whl", hash = "sha256:7560ccd16e7596a047e15a307cec12dbd88fdcaab45a75759e5c6adef22a27d1", size = 38009, upload-time = "2026-01-23T11:11:14.697Z" },
]
[[package]]
name = "langgraph-checkpoint-postgres"
version = "3.0.5"
@@ -3459,6 +3525,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/e6/38/84bf29f4dd72e6c450546df6ca8f53021f764fd945ba67dcc235d39bc20e/pymupdf4llm-1.27.2.3-py3-none-any.whl", hash = "sha256:bd724b79fa3f06a5b28d7a65f7acfa8de56e04bdb603ac2d6dff315e0d151aaa", size = 77348, upload-time = "2026-04-24T14:11:04.305Z" },
]
[[package]]
name = "pymysql"
version = "1.1.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/7f/ec/8d45c920e90445f0b75c590b32851853ed319763b0d8dff8d283052da8cf/pymysql-1.1.3.tar.gz", hash = "sha256:e70ebf2047a4edf6138cf79c68ad418ef620af65900aa585c5e8bfc95044d43a", size = 48207, upload-time = "2026-05-01T09:09:54.532Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/8e/dc/9085f3d6f497e9b25fb40d6e8ecef3ddbb5cf977a949b933624a299f5c16/pymysql-1.1.3-py3-none-any.whl", hash = "sha256:8164ba62c552f6105f3b11753352d0f16b90d1703ba67d81923d5a8a5d1c5289", size = 45356, upload-time = "2026-05-01T09:09:53.316Z" },
]
[[package]]
name = "pypdfium2"
version = "5.7.1"
-8
View File
@@ -1029,14 +1029,6 @@ run_events:
# client_secret: $DINGTALK_CLIENT_SECRET
# allowed_users: [] # empty = allow all
# card_template_id: "" # Optional: AI Card template ID for streaming updates
#
# discord:
# enabled: false
# bot_token: $DISCORD_BOT_TOKEN
# allowed_guilds: [] # empty = allow all guilds; can also be a single guild ID
# mention_only: false # If true, only respond when the bot is mentioned
# allowed_channels: [] # Optional: channel IDs exempt from mention_only (bot responds without mention)
# thread_mode: false # If true, group a channel conversation into a thread
# ============================================================================
# Guardrails Configuration
+3 -21
View File
@@ -28,10 +28,6 @@ http {
set $gateway_upstream gateway:8001;
set $frontend_upstream frontend:3000;
# Default proxy settings for all locations (streaming/SSE support)
proxy_buffering off;
proxy_cache off;
# Keep the unified nginx endpoint same-origin by default. When split
# frontend/backend or port-forwarded deployments need browser CORS,
# configure the Gateway allowlist with GATEWAY_CORS_ORIGINS so CORS and
@@ -53,6 +49,8 @@ http {
proxy_set_header Connection '';
# SSE/Streaming support
proxy_buffering off;
proxy_cache off;
proxy_set_header X-Accel-Buffering no;
# Timeouts for long-running requests
@@ -72,7 +70,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Custom API: Memory endpoint
@@ -83,7 +80,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Custom API: MCP configuration endpoint
@@ -94,7 +90,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Custom API: Skills configuration endpoint
@@ -105,7 +100,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Custom API: Agents endpoint
@@ -116,7 +110,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Custom API: Uploads endpoint
@@ -131,8 +124,6 @@ http {
# Large file upload support
client_max_body_size 100M;
proxy_request_buffering off;
# Disable response buffering to avoid permission errors
}
# Custom API: Other endpoints under /api/threads
@@ -143,7 +134,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# API Documentation: Swagger UI
@@ -154,7 +144,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# API Documentation: ReDoc
@@ -165,7 +154,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# API Documentation: OpenAPI Schema
@@ -176,7 +164,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Health check endpoint (gateway)
@@ -187,7 +174,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# ── Provisioner API (sandbox management) ────────────────────────
@@ -201,7 +187,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
}
# Catch-all for /api/ routes not covered above (e.g. /api/v1/auth/*).
@@ -213,9 +198,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# Disable buffering to avoid permission errors when nginx
# runs as a non-root user (e.g. local development).
}
# All other requests go to frontend
@@ -238,4 +220,4 @@ http {
proxy_read_timeout 600s;
}
}
}
}
-41
View File
@@ -70,11 +70,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
# Disable buffering to avoid permission errors when nginx
# runs as a non-root user (e.g. local development).
proxy_buffering off;
proxy_cache off;
}
# Custom API: Memory endpoint
@@ -85,9 +80,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Custom API: MCP configuration endpoint
@@ -98,9 +90,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Custom API: Skills configuration endpoint
@@ -111,9 +100,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Custom API: Agents endpoint
@@ -124,9 +110,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Custom API: Uploads endpoint
@@ -141,10 +124,6 @@ http {
# Large file upload support
client_max_body_size 100M;
proxy_request_buffering off;
# Disable response buffering to avoid permission errors
proxy_buffering off;
proxy_cache off;
}
# Custom API: Other endpoints under /api/threads
@@ -155,9 +134,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# API Documentation: Swagger UI
@@ -168,9 +144,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# API Documentation: ReDoc
@@ -181,9 +154,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# API Documentation: OpenAPI Schema
@@ -194,9 +164,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Health check endpoint (gateway)
@@ -207,9 +174,6 @@ http {
proxy_set_header X-Real-IP $remote_addr;
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
proxy_set_header X-Forwarded-Proto $scheme;
proxy_buffering off;
proxy_cache off;
}
# Catch-all for any /api/* prefix not matched by a more specific block above.
@@ -229,11 +193,6 @@ http {
# Auth endpoints set HttpOnly cookies — make sure nginx doesn't
# strip the Set-Cookie header from upstream responses.
proxy_pass_header Set-Cookie;
# Disable buffering to avoid permission errors when nginx
# runs as a non-root user (e.g. local development).
proxy_buffering off;
proxy_cache off;
}
# All other requests go to frontend
@@ -66,7 +66,6 @@ export default function AgentChatPage() {
thread,
pendingUsageMessages,
sendMessage,
isUploading,
isHistoryLoading,
hasMoreHistory,
loadMoreHistory,
@@ -107,11 +106,7 @@ export default function AgentChatPage() {
const handleSubmit = useCallback(
(message: PromptInputMessage) => {
const sendPromise = sendMessage(threadId, message, { agent_name });
if (message.files.length > 0) {
return sendPromise;
}
void sendPromise;
void sendMessage(threadId, message, { agent_name });
},
[sendMessage, threadId, agent_name],
);
@@ -248,10 +243,7 @@ export default function AgentChatPage() {
<AgentWelcome agent={agent} agentName={agent_name} />
)
}
disabled={
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" ||
isUploading
}
disabled={env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true"}
onContextChange={(context) => setSettings("context", context)}
onSubmit={handleSubmit}
onStop={handleStop}
@@ -109,11 +109,7 @@ export default function ChatPage() {
const handleSubmit = useCallback(
(message: PromptInputMessage) => {
const sendPromise = sendMessage(threadId, message);
if (message.files.length > 0) {
return sendPromise;
}
void sendPromise;
void sendMessage(threadId, message);
},
[sendMessage, threadId],
);
@@ -499,10 +499,6 @@ export const PromptInput = ({
// Keep a ref to files for cleanup on unmount (avoids stale closure)
const filesRef = useRef(files);
filesRef.current = files;
const providerTextRef = useRef("");
if (usingProvider) {
providerTextRef.current = controller.textInput.value;
}
const openFileDialogLocal = useCallback(() => {
inputRef.current?.click();
@@ -772,24 +768,6 @@ export const PromptInput = ({
}
// Convert blob URLs to data URLs asynchronously
const submittedFileIds = files.map((file) => file.id);
const clearSubmittedState = () => {
const currentFileIds = new Set(filesRef.current.map((file) => file.id));
const submittedFileIdsStillPresent = submittedFileIds.filter((id) =>
currentFileIds.has(id),
);
if (submittedFileIdsStillPresent.length === filesRef.current.length) {
clear();
} else {
for (const id of submittedFileIdsStillPresent) {
remove(id);
}
}
if (usingProvider && providerTextRef.current === text) {
controller.textInput.clear();
}
};
Promise.all(
files.map(async ({ id, ...item }) => {
if (item.file instanceof File) {
@@ -815,14 +793,20 @@ export const PromptInput = ({
if (result instanceof Promise) {
result
.then(() => {
clearSubmittedState();
clear();
if (usingProvider) {
controller.textInput.clear();
}
})
.catch(() => {
// Don't clear on error - user may want to retry
});
} else {
// Sync function completed without throwing, clear attachments
clearSubmittedState();
clear();
if (usingProvider) {
controller.textInput.clear();
}
}
} catch {
// Don't clear on error - user may want to retry
@@ -110,7 +110,6 @@ export function InputBox({
threadId,
initialValue,
onContextChange,
onFollowupsVisibilityChange,
onSubmit,
onStop,
...props
@@ -143,8 +142,7 @@ export function InputBox({
reasoning_effort?: "minimal" | "low" | "medium" | "high";
},
) => void;
onFollowupsVisibilityChange?: (visible: boolean) => void;
onSubmit?: (message: PromptInputMessage) => void | Promise<void>;
onSubmit?: (message: PromptInputMessage) => void;
onStop?: () => void;
}) {
const { t } = useI18n();
@@ -253,12 +251,12 @@ export function InputBox({
);
const handleSubmit = useCallback(
(message: PromptInputMessage) => {
async (message: PromptInputMessage) => {
if (status === "streaming") {
onStop?.();
return;
}
if (!message.text.trim() && message.files.length === 0) {
if (!message.text) {
return;
}
setFollowups([]);
@@ -276,14 +274,11 @@ export function InputBox({
selectedModel?.supports_thinking ?? false,
),
});
return new Promise<void>((resolve, reject) => {
setTimeout(() => {
Promise.resolve(onSubmit?.(message)).then(resolve).catch(reject);
}, 0);
});
setTimeout(() => onSubmit?.(message), 0);
return;
}
return onSubmit?.(message);
onSubmit?.(message);
},
[
context,
@@ -353,14 +348,6 @@ export function InputBox({
!followupsHidden &&
(followupsLoading || followups.length > 0);
useEffect(() => {
onFollowupsVisibilityChange?.(showFollowups);
}, [onFollowupsVisibilityChange, showFollowups]);
useEffect(() => {
return () => onFollowupsVisibilityChange?.(false);
}, [onFollowupsVisibilityChange]);
useEffect(() => {
messagesRef.current = thread.messages;
}, [thread.messages]);
+6 -9
View File
@@ -26,13 +26,6 @@ export type MessageGroup =
| AssistantClarificationGroup
| AssistantSubagentGroup;
const HIDDEN_CONTROL_MESSAGE_NAMES = new Set([
"summary",
"loop_warning",
"todo_reminder",
"todo_completion_reminder",
]);
export function getMessageGroups(messages: Message[]): MessageGroup[] {
if (messages.length === 0) {
return [];
@@ -60,6 +53,10 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] {
continue;
}
if (message.name === "todo_reminder") {
continue;
}
if (message.type === "human") {
groups.push({ id: message.id, type: "human", messages: [message] });
continue;
@@ -371,8 +368,8 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) {
export function isHiddenFromUIMessage(message: Message) {
return (
message.additional_kwargs?.hide_from_ui === true ||
(typeof message.name === "string" &&
HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name))
message.name === "summary" ||
message.name === "loop_warning"
);
}
+41 -134
View File
@@ -45,60 +45,15 @@ type SendMessageOptions = {
additionalKwargs?: Record<string, unknown>;
};
function isNonEmptyString(value: string | undefined): value is string {
return typeof value === "string" && value.length > 0;
}
function messageIdentity(message: Message): string | undefined {
if (
"tool_call_id" in message &&
typeof message.tool_call_id === "string" &&
message.tool_call_id.length > 0
) {
return `tool:${message.tool_call_id}`;
}
if (typeof message.id === "string" && message.id.length > 0) {
return `message:${message.id}`;
}
return undefined;
}
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
const lastIndexByIdentity = new Map<string, number>();
messages.forEach((message, index) => {
const identity = messageIdentity(message);
if (identity) {
lastIndexByIdentity.set(identity, index);
}
});
return messages.filter((message, index) => {
const identity = messageIdentity(message);
return !identity || lastIndexByIdentity.get(identity) === index;
});
}
function findLatestUnloadedRunIndex(
runs: Run[],
loadedRunIds: ReadonlySet<string>,
): number {
for (let i = runs.length - 1; i >= 0; i--) {
const run = runs[i];
if (run && !loadedRunIds.has(run.run_id)) {
return i;
}
}
return -1;
}
export function mergeMessages(
function mergeMessages(
historyMessages: Message[],
threadMessages: Message[],
optimisticMessages: Message[],
): Message[] {
const threadMessageIds = new Set(
threadMessages.map(messageIdentity).filter(isNonEmptyString),
threadMessages
.map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id))
.filter(Boolean),
);
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
@@ -110,19 +65,28 @@ export function mergeMessages(
if (!msg) {
continue;
}
const identity = messageIdentity(msg);
if (identity && threadMessageIds.has(identity)) {
if (
(msg?.id && threadMessageIds.has(msg.id)) ||
("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id))
) {
cutoff = i;
} else {
break;
}
}
return dedupeMessagesByIdentity([
return [
...historyMessages.slice(0, cutoff),
...threadMessages,
...optimisticMessages,
]);
];
}
function messageIdentity(message: Message): string | undefined {
if ("tool_call_id" in message) {
return message.tool_call_id;
}
return message.id;
}
function getMessagesAfterBaseline(
@@ -663,105 +627,48 @@ export function useThreadHistory(threadId: string) {
const runsRef = useRef(runs.data ?? []);
const indexRef = useRef(-1);
const loadingRef = useRef(false);
const pendingLoadRef = useRef(false);
const loadingRunIdRef = useRef<string | null>(null);
const loadedRunIdsRef = useRef<Set<string>>(new Set());
const [loading, setLoading] = useState(false);
const [messages, setMessages] = useState<Message[]>([]);
loadingRef.current = loading;
const loadMessages = useCallback(async () => {
if (loadingRef.current) {
const pendingRunIndex = findLatestUnloadedRunIndex(
runsRef.current,
loadedRunIdsRef.current,
);
const pendingRun = runsRef.current[pendingRunIndex];
if (pendingRun && pendingRun.run_id !== loadingRunIdRef.current) {
pendingLoadRef.current = true;
}
return;
}
if (runsRef.current.length === 0) {
return;
}
loadingRef.current = true;
setLoading(true);
const run = runsRef.current[indexRef.current];
if (!run || loadingRef.current) {
return;
}
try {
do {
pendingLoadRef.current = false;
const nextRunIndex = findLatestUnloadedRunIndex(
runsRef.current,
loadedRunIdsRef.current,
);
indexRef.current = nextRunIndex;
const run = runsRef.current[nextRunIndex];
if (!run) {
indexRef.current = -1;
return;
}
const requestThreadId = threadIdRef.current;
loadingRunIdRef.current = run.run_id;
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(requestThreadId)}/runs/${encodeURIComponent(run.run_id)}/messages`,
{
method: "GET",
headers: {
"Content-Type": "application/json",
},
credentials: "include",
setLoading(true);
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`,
{
method: "GET",
headers: {
"Content-Type": "application/json",
},
).then((res) => {
return res.json();
});
const _messages = result.data
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
.map((m) => m.content);
if (threadIdRef.current !== requestThreadId) {
return;
}
setMessages((prev) =>
dedupeMessagesByIdentity([..._messages, ...prev]),
);
loadedRunIdsRef.current.add(run.run_id);
indexRef.current = findLatestUnloadedRunIndex(
runsRef.current,
loadedRunIdsRef.current,
);
} while (pendingLoadRef.current);
credentials: "include",
},
).then((res) => {
return res.json();
});
const _messages = result.data
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
.map((m) => m.content);
setMessages((prev) => [..._messages, ...prev]);
indexRef.current -= 1;
} catch (err) {
console.error(err);
} finally {
loadingRef.current = false;
loadingRunIdRef.current = null;
setLoading(false);
}
}, []);
useEffect(() => {
const threadChanged = threadIdRef.current !== threadId;
threadIdRef.current = threadId;
if (threadChanged) {
runsRef.current = [];
indexRef.current = -1;
pendingLoadRef.current = false;
loadingRunIdRef.current = null;
loadedRunIdsRef.current = new Set();
loadingRef.current = false;
setLoading(false);
setMessages([]);
}
if (runs.data && runs.data.length > 0) {
runsRef.current = runs.data ?? [];
indexRef.current = findLatestUnloadedRunIndex(
runs.data,
loadedRunIdsRef.current,
);
indexRef.current = runs.data.length - 1;
}
loadMessages().catch(() => {
toast.error("Failed to load thread history.");
@@ -770,7 +677,7 @@ export function useThreadHistory(threadId: string) {
const appendMessages = useCallback((_messages: Message[]) => {
setMessages((prev) => {
return dedupeMessagesByIdentity([...prev, ..._messages]);
return [...prev, ..._messages];
});
}, []);
const hasMore = indexRef.current >= 0 || !runs.data;
-62
View File
@@ -48,66 +48,4 @@ test.describe("Chat workspace", () => {
timeout: 10_000,
});
});
test("keeps attachments visible while upload submit is pending", async ({
page,
}) => {
let releaseUpload!: () => void;
const uploadCanFinish = new Promise<void>((resolve) => {
releaseUpload = resolve;
});
let uploadStarted!: () => void;
const uploadStartedPromise = new Promise<void>((resolve) => {
uploadStarted = resolve;
});
await page.route("**/api/threads/*/uploads", async (route) => {
uploadStarted();
await uploadCanFinish;
return route.fulfill({
status: 200,
contentType: "application/json",
body: JSON.stringify({
success: true,
message: "Uploaded",
files: [
{
filename: "report.docx",
size: 12,
path: "report.docx",
virtual_path: "/mnt/user-data/uploads/report.docx",
artifact_url: "/api/threads/test/uploads/report.docx",
extension: ".docx",
},
],
}),
});
});
await page.goto("/workspace/chats/new");
const textarea = page.getByPlaceholder(/how can i assist you/i);
await expect(textarea).toBeVisible({ timeout: 15_000 });
const promptForm = page.locator("form").filter({ has: textarea });
await page.getByLabel("Upload files").setInputFiles({
name: "report.docx",
mimeType:
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
buffer: Buffer.from("fake docx"),
});
await expect(promptForm.getByText("report.docx")).toBeVisible();
await textarea.fill("Summarize this document");
await textarea.press("Enter");
await uploadStartedPromise;
await expect(promptForm.getByText("report.docx")).toBeVisible();
releaseUpload();
await expect(page.getByText("Hello from DeerFlow!")).toBeVisible({
timeout: 10_000,
});
await expect(promptForm.getByText("report.docx")).toBeHidden();
});
});
@@ -63,37 +63,3 @@ test("aggregates token usage messages once per assistant turn", () => {
),
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
});
test("hides internal todo reminder messages from message groups", () => {
const messages = [
{
id: "human-1",
type: "human",
content: "Audit the middleware",
},
{
id: "todo-reminder-1",
type: "human",
name: "todo_completion_reminder",
content: "<system_reminder>finish todos</system_reminder>",
},
{
id: "todo-reminder-2",
type: "human",
name: "todo_reminder",
content: "<system_reminder>remember todos</system_reminder>",
},
{
id: "ai-1",
type: "ai",
content: "Done",
},
] as Message[];
const groups = getMessageGroups(messages);
expect(groups.map((group) => group.type)).toEqual(["human", "assistant"]);
expect(
groups.flatMap((group) => group.messages).map((message) => message.id),
).toEqual(["human-1", "ai-1"]);
});
@@ -1,64 +0,0 @@
import type { Message } from "@langchain/langgraph-sdk";
import { expect, test } from "vitest";
import { mergeMessages } from "@/core/threads/hooks";
test("mergeMessages removes duplicate messages already present in history", () => {
const human = {
id: "human-1",
type: "human",
content: "Design an agent",
} as Message;
const ai = {
id: "ai-1",
type: "ai",
content: "Let's design it.",
} as Message;
expect(mergeMessages([human, ai, human, ai], [], [])).toEqual([human, ai]);
});
test("mergeMessages lets live thread messages replace overlapping history", () => {
const oldHuman = {
id: "human-1",
type: "human",
content: "old",
} as Message;
const liveHuman = {
id: "human-1",
type: "human",
content: "live",
} as Message;
const oldAi = {
id: "ai-1",
type: "ai",
content: "old",
} as Message;
const liveAi = {
id: "ai-1",
type: "ai",
content: "live",
} as Message;
expect(mergeMessages([oldHuman, oldAi], [liveHuman, liveAi], [])).toEqual([
liveHuman,
liveAi,
]);
});
test("mergeMessages deduplicates tool messages by tool_call_id", () => {
const oldTool = {
id: "tool-message-old",
type: "tool",
tool_call_id: "call-1",
content: "old",
} as Message;
const liveTool = {
id: "tool-message-live",
type: "tool",
tool_call_id: "call-1",
content: "live",
} as Message;
expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]);
});
-81
View File
@@ -72,7 +72,6 @@ def find_config_file() -> Path | None:
_SECTION_RE = re.compile(r"^([A-Za-z_][\w-]*)\s*:\s*$")
_INDENTED_SECTION_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*$")
_KEY_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*(\S.*?)\s*$")
@@ -142,84 +141,6 @@ def section_value(lines: list[str], section: str, key: str) -> str | None:
return None
def nested_section_value(lines: list[str], section_path: str, key: str) -> str | None:
"""Return the value of a nested YAML key like ``channels.discord.enabled``.
Handles two levels of nesting:
channels:
discord:
enabled: true
"""
parts = section_path.split(".")
if len(parts) != 2:
return None
parent_section, child_section = parts
inside_parent = False
inside_child = False
parent_indent: int | None = None
child_indent: int | None = None
for raw in lines:
line = _strip_comment(raw)
if not line.strip():
continue
stripped = line.lstrip()
indent = len(line) - len(stripped)
# Top-level section match
sect_match = _SECTION_RE.match(line)
if sect_match:
if indent == 0:
inside_parent = sect_match.group(1) == parent_section
inside_child = False
parent_indent = None
child_indent = None
continue
if not inside_parent:
continue
# Track parent indent from first child
if parent_indent is None and indent > 0:
parent_indent = indent
# If indent goes back to 0, we left the parent section
if indent == 0:
inside_parent = False
inside_child = False
continue
# Check if we're at the parent's child level (subsection)
if parent_indent is not None and indent == parent_indent:
# This could be a subsection or a direct key of parent
sub_match = _INDENTED_SECTION_RE.match(line)
if sub_match and sub_match.group(1) == child_section:
inside_child = True
child_indent = None
continue
else:
inside_child = False
continue
if not inside_child:
continue
# We're inside the subsection — track child indent
if child_indent is None and indent > (parent_indent or 0):
child_indent = indent
if child_indent is not None and indent != child_indent:
continue
key_match = _KEY_RE.match(line)
if key_match and key_match.group(1) == key:
return _unquote(key_match.group(2).strip())
return None
def detect_from_config(path: Path) -> list[str]:
try:
text = path.read_text(encoding="utf-8", errors="replace")
@@ -231,8 +152,6 @@ def detect_from_config(path: Path) -> list[str]:
extras.add("postgres")
if (section_value(lines, "checkpointer", "type") or "").lower() == "postgres":
extras.add("postgres")
if (nested_section_value(lines, "channels.discord", "enabled") or "").lower() == "true":
extras.add("discord")
return sorted(extras)