Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d7a2fff7e0 | |||
| eabd78ce4e | |||
| 533d3fbfee | |||
| d6b3a277a5 | |||
| def2a3ad79 | |||
| 3c0b42d836 | |||
| 34ec205e1d | |||
| 11a9041b65 | |||
| d3066a1746 | |||
| 485f8a2bf2 |
+11
-291
@@ -3,10 +3,8 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
@@ -23,12 +21,6 @@ class DiscordChannel(Channel):
|
||||
Configuration keys (in ``config.yaml`` under ``channels.discord``):
|
||||
- ``bot_token``: Discord Bot token.
|
||||
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
|
||||
- ``mention_only``: (optional) If true, only respond when the bot is mentioned.
|
||||
- ``allowed_channels``: (optional) List of channel IDs where messages are always accepted
|
||||
(even when mention_only is true). Use for channels where you want the bot to respond
|
||||
without mentions. Empty = mention_only applies everywhere.
|
||||
- ``thread_mode``: (optional) If true, group a channel conversation into a thread.
|
||||
Default: same as ``mention_only``.
|
||||
"""
|
||||
|
||||
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
||||
@@ -40,29 +32,6 @@ class DiscordChannel(Channel):
|
||||
self._allowed_guilds.add(int(guild_id))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
self._mention_only: bool = bool(config.get("mention_only", False))
|
||||
self._thread_mode: bool = config.get("thread_mode", self._mention_only)
|
||||
self._allowed_channels: set[str] = set()
|
||||
for channel_id in config.get("allowed_channels", []):
|
||||
self._allowed_channels.add(str(channel_id))
|
||||
|
||||
# Session tracking: channel_id -> Discord thread_id (in-memory, persisted to JSON).
|
||||
# Uses a dedicated JSON file separate from ChannelStore, which maps IM
|
||||
# conversations to DeerFlow thread IDs — a different concern.
|
||||
self._active_threads: dict[str, str] = {}
|
||||
# Reverse-lookup set for O(1) thread ID checks (avoids O(n) scan of _active_threads.values()).
|
||||
self._active_thread_ids: set[str] = set()
|
||||
# Lock protecting _active_threads and the JSON file from concurrent access.
|
||||
# _run_client (Discord loop thread) and the main thread both read/write.
|
||||
self._thread_store_lock = threading.Lock()
|
||||
store = config.get("channel_store")
|
||||
if store is not None:
|
||||
self._thread_store_path = store._path.parent / "discord_threads.json"
|
||||
else:
|
||||
self._thread_store_path = Path.home() / ".deer-flow" / "channels" / "discord_threads.json"
|
||||
|
||||
# Typing indicator management
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
self._client = None
|
||||
self._thread: threading.Thread | None = None
|
||||
@@ -106,56 +75,12 @@ class DiscordChannel(Channel):
|
||||
|
||||
self._thread = threading.Thread(target=self._run_client, daemon=True)
|
||||
self._thread.start()
|
||||
self._load_active_threads()
|
||||
logger.info("Discord channel started")
|
||||
|
||||
def _load_active_threads(self) -> None:
|
||||
"""Restore Discord thread mappings from the dedicated JSON file on startup."""
|
||||
with self._thread_store_lock:
|
||||
try:
|
||||
if not self._thread_store_path.exists():
|
||||
logger.debug("[Discord] no thread mappings file at %s", self._thread_store_path)
|
||||
return
|
||||
data = json.loads(self._thread_store_path.read_text())
|
||||
self._active_threads.clear()
|
||||
self._active_thread_ids.clear()
|
||||
for channel_id, thread_id in data.items():
|
||||
self._active_threads[channel_id] = thread_id
|
||||
self._active_thread_ids.add(thread_id)
|
||||
if self._active_threads:
|
||||
logger.info("[Discord] restored %d thread mappings from %s", len(self._active_threads), self._thread_store_path)
|
||||
except Exception:
|
||||
logger.exception("[Discord] failed to load thread mappings")
|
||||
|
||||
def _save_thread(self, channel_id: str, thread_id: str) -> None:
|
||||
"""Persist a Discord thread mapping to the dedicated JSON file."""
|
||||
with self._thread_store_lock:
|
||||
try:
|
||||
data: dict[str, str] = {}
|
||||
if self._thread_store_path.exists():
|
||||
data = json.loads(self._thread_store_path.read_text())
|
||||
old_id = data.get(channel_id)
|
||||
data[channel_id] = thread_id
|
||||
# Update reverse-lookup set
|
||||
if old_id:
|
||||
self._active_thread_ids.discard(old_id)
|
||||
self._active_thread_ids.add(thread_id)
|
||||
self._thread_store_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._thread_store_path.write_text(json.dumps(data, indent=2))
|
||||
except Exception:
|
||||
logger.exception("[Discord] failed to save thread mapping for channel %s", channel_id)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
self.bus.unsubscribe_outbound(self._on_outbound)
|
||||
|
||||
# Cancel all active typing indicator tasks
|
||||
for target_id, task in list(self._typing_tasks.items()):
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
logger.debug("[Discord] cancelled typing task for target %s", target_id)
|
||||
self._typing_tasks.clear()
|
||||
|
||||
if self._client and self._discord_loop and self._discord_loop.is_running():
|
||||
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
|
||||
try:
|
||||
@@ -175,10 +100,6 @@ class DiscordChannel(Channel):
|
||||
logger.info("Discord channel stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
# Stop typing indicator once we're sending the response
|
||||
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
|
||||
await asyncio.wrap_future(stop_future)
|
||||
|
||||
target = await self._resolve_target(msg)
|
||||
if target is None:
|
||||
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
|
||||
@@ -190,9 +111,6 @@ class DiscordChannel(Channel):
|
||||
await asyncio.wrap_future(send_future)
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
|
||||
await asyncio.wrap_future(stop_future)
|
||||
|
||||
target = await self._resolve_target(msg)
|
||||
if target is None:
|
||||
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
|
||||
@@ -212,41 +130,6 @@ class DiscordChannel(Channel):
|
||||
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
|
||||
return False
|
||||
|
||||
async def _start_typing(self, channel, chat_id: str, thread_ts: str | None = None) -> None:
|
||||
"""Starts a loop to send periodic typing indicators."""
|
||||
target_id = thread_ts or chat_id
|
||||
if target_id in self._typing_tasks:
|
||||
return # Already typing for this target
|
||||
|
||||
async def _typing_loop():
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
await channel.trigger_typing()
|
||||
except Exception:
|
||||
pass
|
||||
await asyncio.sleep(10)
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
task = asyncio.create_task(_typing_loop())
|
||||
self._typing_tasks[target_id] = task
|
||||
|
||||
async def _stop_typing(self, chat_id: str, thread_ts: str | None = None) -> None:
|
||||
"""Stops the typing loop for a specific target."""
|
||||
target_id = thread_ts or chat_id
|
||||
task = self._typing_tasks.pop(target_id, None)
|
||||
if task and not task.done():
|
||||
task.cancel()
|
||||
logger.debug("[Discord] stopped typing indicator for target %s", target_id)
|
||||
|
||||
async def _add_reaction(self, message) -> None:
|
||||
"""Add a checkmark reaction to acknowledge the message was received."""
|
||||
try:
|
||||
await message.add_reaction("✅")
|
||||
except Exception:
|
||||
logger.debug("[Discord] failed to add reaction to message %s", message.id, exc_info=True)
|
||||
|
||||
async def _on_message(self, message) -> None:
|
||||
if not self._running or not self._client:
|
||||
return
|
||||
@@ -269,143 +152,15 @@ class DiscordChannel(Channel):
|
||||
if self._discord_module is None:
|
||||
return
|
||||
|
||||
# Determine whether the bot is mentioned in this message
|
||||
user = self._client.user if self._client else None
|
||||
if user:
|
||||
bot_mention = user.mention # <@ID>
|
||||
alt_mention = f"<@!{user.id}>" # <@!ID> (ping variant)
|
||||
standard_mention = f"<@{user.id}>"
|
||||
else:
|
||||
bot_mention = None
|
||||
alt_mention = None
|
||||
standard_mention = ""
|
||||
has_mention = (bot_mention and bot_mention in message.content) or (alt_mention and alt_mention in message.content) or (standard_mention and standard_mention in message.content)
|
||||
|
||||
# Strip mention from text for processing
|
||||
if has_mention:
|
||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||
|
||||
# --- Determine thread/channel routing and typing target ---
|
||||
thread_id = None
|
||||
chat_id = None
|
||||
typing_target = None # The Discord object to type into
|
||||
|
||||
if isinstance(message.channel, self._discord_module.Thread):
|
||||
# --- Message already inside a thread ---
|
||||
thread_obj = message.channel
|
||||
thread_id = str(thread_obj.id)
|
||||
chat_id = str(thread_obj.parent_id or thread_obj.id)
|
||||
typing_target = thread_obj
|
||||
|
||||
# If this is a known active thread, process normally
|
||||
if thread_id in self._active_thread_ids:
|
||||
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
||||
inbound = self._make_inbound(
|
||||
chat_id=chat_id,
|
||||
user_id=str(message.author.id),
|
||||
text=text,
|
||||
msg_type=msg_type,
|
||||
thread_ts=thread_id,
|
||||
metadata={
|
||||
"guild_id": str(guild.id) if guild else None,
|
||||
"channel_id": str(message.channel.id),
|
||||
"message_id": str(message.id),
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
self._publish(inbound)
|
||||
# Start typing indicator in the thread
|
||||
if typing_target:
|
||||
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
|
||||
asyncio.create_task(self._add_reaction(message))
|
||||
return
|
||||
|
||||
# Thread not tracked (orphaned) — create new thread and handle below
|
||||
logger.debug("[Discord] message in orphaned thread %s, will create new thread", thread_id)
|
||||
thread_id = None
|
||||
typing_target = None
|
||||
|
||||
# At this point we're guaranteed to be in a channel, not a thread
|
||||
# (the Thread case is handled above). Apply mention_only for all
|
||||
# non-thread messages — no special case needed.
|
||||
channel_id = str(message.channel.id)
|
||||
|
||||
# Check if there's an active thread for this channel
|
||||
if channel_id in self._active_threads:
|
||||
# respect mention_only: if enabled, only process messages that mention the bot
|
||||
# (unless the channel is in allowed_channels)
|
||||
# Messages within a thread are always allowed through (continuation).
|
||||
# At this code point we know the message is in a channel, not a thread
|
||||
# (Thread case handled above), so always apply the check.
|
||||
if self._mention_only and not has_mention and channel_id not in self._allowed_channels:
|
||||
logger.debug("[Discord] skipping no-@ message in channel %s (not in thread)", channel_id)
|
||||
return
|
||||
# mention_only + fresh @ → create new thread instead of routing to existing one
|
||||
if self._mention_only and has_mention:
|
||||
thread_obj = await self._create_thread(message)
|
||||
if thread_obj is not None:
|
||||
target_thread_id = str(thread_obj.id)
|
||||
self._active_threads[channel_id] = target_thread_id
|
||||
self._save_thread(channel_id, target_thread_id)
|
||||
thread_id = target_thread_id
|
||||
chat_id = channel_id
|
||||
typing_target = thread_obj
|
||||
logger.info("[Discord] created new thread %s in channel %s on mention (replacing existing thread)", target_thread_id, channel_id)
|
||||
else:
|
||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
||||
thread_id = channel_id
|
||||
chat_id = channel_id
|
||||
typing_target = message.channel
|
||||
else:
|
||||
# Existing session → route to the existing thread
|
||||
target_thread_id = self._active_threads[channel_id]
|
||||
logger.debug("[Discord] routing message in channel %s to existing thread %s", channel_id, target_thread_id)
|
||||
thread_id = target_thread_id
|
||||
chat_id = channel_id
|
||||
typing_target = await self._get_channel_or_thread(target_thread_id)
|
||||
elif self._mention_only and not has_mention and channel_id not in self._allowed_channels:
|
||||
# Not mentioned and not in an allowed channel → skip
|
||||
logger.debug("[Discord] skipping message without mention in channel %s", channel_id)
|
||||
return
|
||||
elif self._mention_only and has_mention:
|
||||
# First mention in this channel → create thread
|
||||
thread_obj = await self._create_thread(message)
|
||||
if thread_obj is not None:
|
||||
target_thread_id = str(thread_obj.id)
|
||||
self._active_threads[channel_id] = target_thread_id
|
||||
self._save_thread(channel_id, target_thread_id)
|
||||
thread_id = target_thread_id
|
||||
chat_id = channel_id
|
||||
typing_target = thread_obj # Type into the new thread
|
||||
logger.info("[Discord] created thread %s in channel %s for user %s", target_thread_id, channel_id, message.author.display_name)
|
||||
else:
|
||||
# Fallback: thread creation failed (disabled/permissions), reply in channel
|
||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
||||
thread_id = channel_id
|
||||
chat_id = channel_id
|
||||
typing_target = message.channel # Type into the channel
|
||||
elif self._thread_mode:
|
||||
# thread_mode but mention_only is False → create thread anyway for conversation grouping
|
||||
thread_obj = await self._create_thread(message)
|
||||
if thread_obj is None:
|
||||
# Thread creation failed (disabled/permissions), fall back to channel replies
|
||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
||||
thread_id = channel_id
|
||||
chat_id = channel_id
|
||||
typing_target = message.channel # Type into the channel
|
||||
else:
|
||||
target_thread_id = str(thread_obj.id)
|
||||
self._active_threads[channel_id] = target_thread_id
|
||||
self._save_thread(channel_id, target_thread_id)
|
||||
thread_id = target_thread_id
|
||||
chat_id = channel_id
|
||||
typing_target = thread_obj # Type into the new thread
|
||||
chat_id = str(message.channel.parent_id or message.channel.id)
|
||||
thread_id = str(message.channel.id)
|
||||
else:
|
||||
# No threading — reply directly in channel
|
||||
thread_id = channel_id
|
||||
chat_id = channel_id
|
||||
typing_target = message.channel # Type into the channel
|
||||
thread = await self._create_thread(message)
|
||||
if thread is None:
|
||||
return
|
||||
chat_id = str(message.channel.id)
|
||||
thread_id = str(thread.id)
|
||||
|
||||
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
||||
inbound = self._make_inbound(
|
||||
@@ -422,15 +177,6 @@ class DiscordChannel(Channel):
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
|
||||
# Start typing indicator in the correct target (thread or channel)
|
||||
if typing_target:
|
||||
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
|
||||
|
||||
self._publish(inbound)
|
||||
asyncio.create_task(self._add_reaction(message))
|
||||
|
||||
def _publish(self, inbound) -> None:
|
||||
"""Publish an inbound message to the main event loop."""
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||
@@ -452,40 +198,14 @@ class DiscordChannel(Channel):
|
||||
|
||||
async def _create_thread(self, message):
|
||||
try:
|
||||
if self._discord_module is None:
|
||||
return None
|
||||
|
||||
# Only TextChannel (type 0) and NewsChannel (type 10) support threads
|
||||
channel_type = message.channel.type
|
||||
if channel_type not in (
|
||||
self._discord_module.ChannelType.text,
|
||||
self._discord_module.ChannelType.news,
|
||||
):
|
||||
logger.info(
|
||||
"[Discord] channel type %s (%s) does not support threads",
|
||||
channel_type.value,
|
||||
channel_type.name,
|
||||
)
|
||||
return None
|
||||
|
||||
thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
|
||||
return await message.create_thread(name=thread_name)
|
||||
except self._discord_module.errors.HTTPException as exc:
|
||||
if exc.code == 50024:
|
||||
logger.info(
|
||||
"[Discord] cannot create thread in channel %s (error code 50024): %s",
|
||||
message.channel.id,
|
||||
channel_type.name if (channel_type := message.channel.type) else "unknown",
|
||||
)
|
||||
else:
|
||||
logger.exception(
|
||||
"[Discord] failed to create thread for message=%s (HTTPException %s)",
|
||||
message.id,
|
||||
exc.code,
|
||||
)
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
|
||||
try:
|
||||
await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.")
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
async def _resolve_target(self, msg: OutboundMessage):
|
||||
|
||||
@@ -787,22 +787,13 @@ class ChannelManager:
|
||||
return
|
||||
|
||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||
try:
|
||||
result = await client.runs.wait(
|
||||
thread_id,
|
||||
assistant_id,
|
||||
input={"messages": [{"role": "human", "content": msg.text}]},
|
||||
config=run_config,
|
||||
context=run_context,
|
||||
multitask_strategy="reject",
|
||||
)
|
||||
except Exception as exc:
|
||||
if _is_thread_busy_error(exc):
|
||||
logger.warning("[Manager] thread busy (concurrent run rejected): thread_id=%s", thread_id)
|
||||
await self._send_error(msg, THREAD_BUSY_MESSAGE)
|
||||
return
|
||||
else:
|
||||
raise
|
||||
result = await client.runs.wait(
|
||||
thread_id,
|
||||
assistant_id,
|
||||
input={"messages": [{"role": "human", "content": msg.text}]},
|
||||
config=run_config,
|
||||
context=run_context,
|
||||
)
|
||||
|
||||
response_text = _extract_response_text(result)
|
||||
artifacts = _extract_artifacts(result)
|
||||
|
||||
@@ -167,8 +167,6 @@ class ChannelService:
|
||||
return False
|
||||
|
||||
try:
|
||||
config = dict(config)
|
||||
config["channel_store"] = self.store
|
||||
channel = channel_cls(bus=self.bus, config=config)
|
||||
self._channels[name] = channel
|
||||
await channel.start()
|
||||
|
||||
@@ -8,8 +8,6 @@ from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SECRET_FILE = ".jwt_secret"
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""JWT and auth-related configuration. Parsed once at startup.
|
||||
@@ -32,32 +30,6 @@ class AuthConfig(BaseModel):
|
||||
_auth_config: AuthConfig | None = None
|
||||
|
||||
|
||||
def _load_or_create_secret() -> str:
|
||||
"""Load persisted JWT secret from ``{base_dir}/.jwt_secret``, or generate and persist a new one."""
|
||||
from deerflow.config.paths import get_paths
|
||||
|
||||
paths = get_paths()
|
||||
secret_file = paths.base_dir / _SECRET_FILE
|
||||
|
||||
try:
|
||||
if secret_file.exists():
|
||||
secret = secret_file.read_text(encoding="utf-8").strip()
|
||||
if secret:
|
||||
return secret
|
||||
except OSError as exc:
|
||||
raise RuntimeError(f"Failed to read JWT secret from {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can read its persisted auth secret.") from exc
|
||||
|
||||
secret = secrets.token_urlsafe(32)
|
||||
try:
|
||||
secret_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
fd = os.open(secret_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
|
||||
with os.fdopen(fd, "w", encoding="utf-8") as fh:
|
||||
fh.write(secret)
|
||||
except OSError as exc:
|
||||
raise RuntimeError(f"Failed to persist JWT secret to {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can store a stable auth secret.") from exc
|
||||
return secret
|
||||
|
||||
|
||||
def get_auth_config() -> AuthConfig:
|
||||
"""Get the global AuthConfig instance. Parses from env on first call."""
|
||||
global _auth_config
|
||||
@@ -67,11 +39,11 @@ def get_auth_config() -> AuthConfig:
|
||||
load_dotenv()
|
||||
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
||||
if not jwt_secret:
|
||||
jwt_secret = _load_or_create_secret()
|
||||
jwt_secret = secrets.token_urlsafe(32)
|
||||
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
||||
logger.warning(
|
||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated secret "
|
||||
"persisted to .jwt_secret. Sessions will survive restarts. "
|
||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
|
||||
"Sessions will be invalidated on restart. "
|
||||
"For production, add AUTH_JWT_SECRET to your .env file: "
|
||||
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||
)
|
||||
|
||||
@@ -20,9 +20,6 @@ ACTIVE_CONTENT_MIME_TYPES = {
|
||||
"image/svg+xml",
|
||||
}
|
||||
|
||||
MAX_SKILL_ARCHIVE_MEMBER_BYTES = 16 * 1024 * 1024
|
||||
_SKILL_ARCHIVE_READ_CHUNK_SIZE = 64 * 1024
|
||||
|
||||
|
||||
def _build_content_disposition(disposition_type: str, filename: str) -> str:
|
||||
"""Build an RFC 5987 encoded Content-Disposition header value."""
|
||||
@@ -47,22 +44,6 @@ def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _read_skill_archive_member(zip_ref: zipfile.ZipFile, info: zipfile.ZipInfo) -> bytes:
|
||||
"""Read a .skill archive member while enforcing an uncompressed size cap."""
|
||||
if info.file_size > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
|
||||
raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
|
||||
|
||||
chunks: list[bytes] = []
|
||||
total_read = 0
|
||||
with zip_ref.open(info, "r") as src:
|
||||
while chunk := src.read(_SKILL_ARCHIVE_READ_CHUNK_SIZE):
|
||||
total_read += len(chunk)
|
||||
if total_read > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
|
||||
raise HTTPException(status_code=413, detail="Skill archive member is too large to preview")
|
||||
chunks.append(chunk)
|
||||
return b"".join(chunks)
|
||||
|
||||
|
||||
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
|
||||
"""Extract a file from a .skill ZIP archive.
|
||||
|
||||
@@ -79,16 +60,16 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
# List all files in the archive
|
||||
infos_by_name = {info.filename: info for info in zip_ref.infolist()}
|
||||
namelist = zip_ref.namelist()
|
||||
|
||||
# Try direct path first
|
||||
if internal_path in infos_by_name:
|
||||
return _read_skill_archive_member(zip_ref, infos_by_name[internal_path])
|
||||
if internal_path in namelist:
|
||||
return zip_ref.read(internal_path)
|
||||
|
||||
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
|
||||
for name, info in infos_by_name.items():
|
||||
for name in namelist:
|
||||
if name.endswith("/" + internal_path) or name == internal_path:
|
||||
return _read_skill_archive_member(zip_ref, info)
|
||||
return zip_ref.read(name)
|
||||
|
||||
# Not found
|
||||
return None
|
||||
|
||||
@@ -99,7 +99,7 @@ rm -f backend/.deer-flow/data/deerflow.db
|
||||
| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory |
|
||||
| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory |
|
||||
| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) |
|
||||
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成并持久化到 `.deer-flow/.jwt_secret`,重启后 session 保持) |
|
||||
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) |
|
||||
|
||||
### 生产环境建议
|
||||
|
||||
@@ -137,4 +137,4 @@ python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` |
|
||||
| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin |
|
||||
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
|
||||
| 重启后需要重新登录 | `.jwt_secret` 文件被删除且 `.env` 未设置 `AUTH_JWT_SECRET` | 在 `.env` 中设置固定密钥 |
|
||||
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |
|
||||
|
||||
@@ -0,0 +1,401 @@
|
||||
# Storage Package Design
|
||||
|
||||
## Background
|
||||
|
||||
DeerFlow currently has several persistence responsibilities spread across app, gateway, runtime, and legacy persistence modules. This makes the persistence boundary difficult to reason about and creates several migration risks:
|
||||
|
||||
- Routers and runtime services can accidentally depend on concrete persistence implementations instead of stable contracts.
|
||||
- User/auth, run metadata, thread metadata, feedback, run events, and checkpointer setup are initialized through different paths.
|
||||
- Some persistence behavior is duplicated between memory, SQLite, and PostgreSQL-oriented code paths.
|
||||
- Incremental migration is hard because app-level code and storage-level code are coupled.
|
||||
- Adding or validating another SQL backend requires touching app/runtime code instead of a storage-owned package.
|
||||
|
||||
The storage package is introduced to make application data persistence a package-level capability with explicit contracts, a clear boundary, and SQL backend compatibility.
|
||||
|
||||
## Goals
|
||||
|
||||
- Provide a standalone `packages/storage` package for durable application data.
|
||||
- Support SQLite, PostgreSQL, and MySQL through a shared persistence construction flow.
|
||||
- Keep LangGraph checkpointer initialization compatible with the same database backend.
|
||||
- Expose repository contracts as the only package-level data access boundary.
|
||||
- Let the app layer depend on app-owned adapters under `app.infra.storage`, not on storage DB implementation classes.
|
||||
- Allow the app/gateway migration to happen in small steps without forcing a large rewrite.
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- This design does not remove legacy persistence in the first PR.
|
||||
- This design does not move routers directly onto storage package models.
|
||||
- This design does not make app routers own SQLAlchemy sessions.
|
||||
- Cron persistence is intentionally out of scope for the storage package foundation.
|
||||
- Memory backend is not part of the durable storage package. Memory compatibility, if still needed by app runtime, belongs outside `packages/storage`.
|
||||
|
||||
## Storage Design Principles
|
||||
|
||||
### Package-Owned Durable Storage
|
||||
|
||||
`packages/storage` owns durable application data persistence. It defines:
|
||||
|
||||
- configuration shape for storage-backed persistence
|
||||
- SQLAlchemy models
|
||||
- repository contracts and DTOs
|
||||
- SQL repository implementations
|
||||
- persistence factory functions
|
||||
- compatibility helpers for config-driven initialization
|
||||
|
||||
The package should be usable without importing `app.gateway`, routers, auth providers, or runtime-specific gateway objects.
|
||||
|
||||
### SQL Backend Compatibility
|
||||
|
||||
The package supports three SQL backends:
|
||||
|
||||
- SQLite for local/single-node deployments
|
||||
- PostgreSQL for production multi-node deployments
|
||||
- MySQL for deployments that standardize on MySQL
|
||||
|
||||
Backend-specific differences are handled inside the storage package:
|
||||
|
||||
- SQLAlchemy async engine URL construction
|
||||
- LangGraph checkpointer connection-string compatibility
|
||||
- JSON metadata filtering across SQLite/PostgreSQL/MySQL
|
||||
- SQL dialect behavior around locking, aggregation, and JSON type semantics
|
||||
|
||||
### Unified Persistence Bundle
|
||||
|
||||
Storage initialization returns an `AppPersistence` bundle:
|
||||
|
||||
```python
|
||||
@dataclass(slots=True)
|
||||
class AppPersistence:
|
||||
checkpointer: Checkpointer
|
||||
engine: AsyncEngine
|
||||
session_factory: async_sessionmaker[AsyncSession]
|
||||
setup: Callable[[], Awaitable[None]]
|
||||
aclose: Callable[[], Awaitable[None]]
|
||||
```
|
||||
|
||||
The app runtime can initialize persistence once, call `setup()`, and then inject:
|
||||
|
||||
- `checkpointer`
|
||||
- `session_factory`
|
||||
- repository adapters
|
||||
|
||||
This keeps checkpointer and application data aligned to the same backend without requiring routers to understand database configuration.
|
||||
|
||||
## Package Layout
|
||||
|
||||
```text
|
||||
backend/packages/storage/
|
||||
store/
|
||||
config/
|
||||
storage_config.py
|
||||
app_config.py
|
||||
persistence/
|
||||
factory.py
|
||||
types.py
|
||||
base_model.py
|
||||
json_compat.py
|
||||
drivers/
|
||||
sqlite.py
|
||||
postgres.py
|
||||
mysql.py
|
||||
repositories/
|
||||
contracts/
|
||||
user.py
|
||||
run.py
|
||||
thread_meta.py
|
||||
feedback.py
|
||||
run_event.py
|
||||
models/
|
||||
user.py
|
||||
run.py
|
||||
thread_meta.py
|
||||
feedback.py
|
||||
run_event.py
|
||||
db/
|
||||
user.py
|
||||
run.py
|
||||
thread_meta.py
|
||||
feedback.py
|
||||
run_event.py
|
||||
factory.py
|
||||
```
|
||||
|
||||
## Persistence Construction
|
||||
|
||||
The primary storage entrypoint is:
|
||||
|
||||
```python
|
||||
from store.persistence import create_persistence_from_storage_config
|
||||
|
||||
persistence = await create_persistence_from_storage_config(storage_config)
|
||||
await persistence.setup()
|
||||
```
|
||||
|
||||
For app-level compatibility with existing database config shape:
|
||||
|
||||
```python
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
|
||||
persistence = await create_persistence_from_database_config(config.database)
|
||||
await persistence.setup()
|
||||
```
|
||||
|
||||
Expected app startup flow:
|
||||
|
||||
```python
|
||||
persistence = await create_persistence_from_database_config(config.database)
|
||||
await persistence.setup()
|
||||
|
||||
app.state.persistence = persistence
|
||||
app.state.checkpointer = persistence.checkpointer
|
||||
app.state.session_factory = persistence.session_factory
|
||||
```
|
||||
|
||||
Expected app shutdown flow:
|
||||
|
||||
```python
|
||||
await app.state.persistence.aclose()
|
||||
```
|
||||
|
||||
## Repository Contract Design
|
||||
|
||||
Repository contracts are the storage package's public data access boundary. They live under `store.repositories.contracts` and are re-exported from `store.repositories`.
|
||||
|
||||
The key contract groups are:
|
||||
|
||||
- `UserRepositoryProtocol`
|
||||
- `RunRepositoryProtocol`
|
||||
- `ThreadMetaRepositoryProtocol`
|
||||
- `FeedbackRepositoryProtocol`
|
||||
- `RunEventRepositoryProtocol`
|
||||
|
||||
Each contract owns:
|
||||
|
||||
- input DTOs, such as `UserCreate`, `RunCreate`, `ThreadMetaCreate`
|
||||
- output DTOs, such as `User`, `Run`, `ThreadMeta`
|
||||
- repository protocol methods
|
||||
- domain-specific exceptions when needed, such as `InvalidMetadataFilterError`
|
||||
|
||||
Repository construction is session-based:
|
||||
|
||||
```python
|
||||
from store.repositories import build_run_repository
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_repository(session)
|
||||
run = await repo.get_run(run_id)
|
||||
```
|
||||
|
||||
This keeps transaction ownership explicit. The storage package does not hide commits or session lifecycle inside global singletons.
|
||||
|
||||
## App/Infra Calling Contract
|
||||
|
||||
The app layer should not call `store.repositories.db.*` directly. The intended app boundary is `app.infra.storage`.
|
||||
|
||||
`app.infra.storage` is responsible for:
|
||||
|
||||
- receiving `session_factory` from FastAPI runtime initialization
|
||||
- owning session lifecycle for app-facing repository methods
|
||||
- translating storage DTOs to app/gateway DTOs only when needed
|
||||
- preserving the existing app-facing names during migration
|
||||
- depending on storage repository protocols, not concrete DB classes
|
||||
|
||||
Expected adapter pattern:
|
||||
|
||||
```python
|
||||
class StorageRunRepository(RunRepositoryProtocol):
|
||||
def __init__(self, session_factory):
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def get_run(self, run_id: str):
|
||||
async with self._session_factory() as session:
|
||||
repo = build_run_repository(session)
|
||||
return await repo.get_run(run_id)
|
||||
```
|
||||
|
||||
For gateway compatibility, app state can keep existing names while the implementation changes:
|
||||
|
||||
```python
|
||||
app.state.run_store = StorageRunStore(run_repository)
|
||||
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
|
||||
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
|
||||
app.state.run_event_store = StorageRunEventStore(run_event_repository)
|
||||
app.state.checkpointer = persistence.checkpointer
|
||||
app.state.session_factory = persistence.session_factory
|
||||
```
|
||||
|
||||
The app-facing objects may expose legacy method names during migration, but their internal data access should go through storage contracts.
|
||||
|
||||
## Boundary Rules
|
||||
|
||||
### Allowed Calls
|
||||
|
||||
Storage package callers may use:
|
||||
|
||||
```python
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.persistence import create_persistence_from_storage_config
|
||||
from store.repositories import build_run_repository
|
||||
from store.repositories import build_user_repository
|
||||
from store.repositories import build_thread_meta_repository
|
||||
from store.repositories import build_feedback_repository
|
||||
from store.repositories import build_run_event_repository
|
||||
from store.repositories import RunRepositoryProtocol
|
||||
from store.repositories import UserRepositoryProtocol
|
||||
```
|
||||
|
||||
App layer callers should use:
|
||||
|
||||
```python
|
||||
from app.infra.storage import StorageRunRepository
|
||||
from app.infra.storage import StorageUserDataRepository
|
||||
from app.infra.storage import StorageThreadMetaRepository
|
||||
from app.infra.storage import StorageFeedbackRepository
|
||||
from app.infra.storage import StorageRunEventRepository
|
||||
```
|
||||
|
||||
### Prohibited Calls
|
||||
|
||||
App/gateway/router/auth code must not import:
|
||||
|
||||
```python
|
||||
from store.repositories.db import DbRunRepository
|
||||
from store.repositories.models import Run
|
||||
from store.persistence.base_model import MappedBase
|
||||
```
|
||||
|
||||
Routers must not:
|
||||
|
||||
- create SQLAlchemy engines
|
||||
- create SQLAlchemy sessions directly
|
||||
- call storage DB repository classes directly
|
||||
- commit/rollback storage transactions directly unless explicitly scoped by an infra adapter
|
||||
- depend on storage SQLAlchemy model classes
|
||||
|
||||
Storage package code must not import:
|
||||
|
||||
```python
|
||||
import app.gateway
|
||||
import app.infra
|
||||
import deerflow.runtime
|
||||
```
|
||||
|
||||
The dependency direction is:
|
||||
|
||||
```text
|
||||
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
|
||||
```
|
||||
|
||||
The reverse direction is forbidden.
|
||||
|
||||
## Checkpointer Compatibility
|
||||
|
||||
The storage persistence bundle initializes the LangGraph checkpointer alongside application data persistence.
|
||||
|
||||
Backend-specific notes:
|
||||
|
||||
- SQLite uses `langgraph-checkpoint-sqlite`.
|
||||
- PostgreSQL uses `langgraph-checkpoint-postgres` and requires a string `postgresql://...` connection URL.
|
||||
- MySQL uses `langgraph-checkpoint-mysql` and requires a string MySQL connection URL.
|
||||
|
||||
SQLAlchemy may use async driver URLs such as `postgresql+asyncpg://...` or `mysql+aiomysql://...`, but LangGraph checkpointer constructors expect plain string connection URLs. This conversion belongs inside the storage driver implementation.
|
||||
|
||||
## JSON Metadata Filtering
|
||||
|
||||
Thread metadata search supports dialect-aware JSON filtering through `store.persistence.json_compat`.
|
||||
|
||||
The matcher supports:
|
||||
|
||||
- `None`
|
||||
- `bool`
|
||||
- `int`
|
||||
- `float`
|
||||
- `str`
|
||||
|
||||
It rejects:
|
||||
|
||||
- unsafe keys
|
||||
- nested JSON path expressions
|
||||
- dict/list values
|
||||
- integers outside signed 64-bit range
|
||||
|
||||
This prevents SQL/JSON path injection, avoids compiled-cache type drift, and preserves type semantics such as `True != 1` and explicit JSON `null` not matching a missing key.
|
||||
|
||||
## Step-by-Step Implementation Plan
|
||||
|
||||
### Step 1: Introduce Storage Package Foundation
|
||||
|
||||
- Add `backend/packages/storage`.
|
||||
- Add storage config models.
|
||||
- Add `AppPersistence`.
|
||||
- Add SQLite/PostgreSQL/MySQL persistence drivers.
|
||||
- Add repository contracts, models, DB implementations, and factory helpers.
|
||||
- Add package dependency wiring.
|
||||
- Exclude cron persistence.
|
||||
|
||||
### Step 2: Harden Storage Backend Compatibility
|
||||
|
||||
- Validate SQLite setup and repository behavior.
|
||||
- Validate PostgreSQL and MySQL with local E2E tests.
|
||||
- Fix checkpointer connection-string compatibility.
|
||||
- Fix PostgreSQL locking and aggregation differences.
|
||||
- Add dialect-aware JSON metadata filtering.
|
||||
|
||||
### Step 3: Add App Infra Adapters
|
||||
|
||||
- Add `backend/app/infra/storage`.
|
||||
- Implement app-facing repositories that own session lifecycle.
|
||||
- Keep storage contracts as the only data access boundary.
|
||||
- Add legacy compatibility adapters for existing app/gateway method shapes.
|
||||
- Keep app/gateway imports out of `packages/storage`.
|
||||
|
||||
### Step 4: Switch FastAPI Runtime Injection
|
||||
|
||||
- Initialize storage persistence in FastAPI startup/lifespan.
|
||||
- Attach `persistence`, `checkpointer`, and `session_factory` to `app.state`.
|
||||
- Preserve existing external state names:
|
||||
- `run_store`
|
||||
- `feedback_repo`
|
||||
- `thread_store`
|
||||
- `run_event_store`
|
||||
- `checkpointer`
|
||||
- `session_factory`
|
||||
- Start with user/auth provider construction, then migrate run/thread/feedback/run_event.
|
||||
|
||||
### Step 5: Router and Auth Compatibility
|
||||
|
||||
- Ensure routers consume app-facing adapters, not storage DB classes.
|
||||
- Ensure auth providers depend on user repository contracts.
|
||||
- Keep router response shapes unchanged.
|
||||
- Add focused auth/admin/router regression tests.
|
||||
|
||||
### Step 6: Cleanup Legacy Persistence
|
||||
|
||||
- Compare old persistence usage after app/gateway migration.
|
||||
- Remove unused old repository implementations only after all call sites move.
|
||||
- Keep compatibility shims only where needed for a transition window.
|
||||
- Delete memory backend paths from storage-owned durable persistence.
|
||||
|
||||
## Testing Strategy
|
||||
|
||||
Unit tests should cover:
|
||||
|
||||
- config parsing
|
||||
- persistence setup
|
||||
- table creation
|
||||
- repository CRUD/query behavior
|
||||
- typed JSON metadata filtering
|
||||
- dialect SQL compilation
|
||||
- cron exclusion
|
||||
|
||||
E2E tests should cover:
|
||||
|
||||
- SQLite persistence setup
|
||||
- PostgreSQL temporary database setup
|
||||
- MySQL temporary database setup
|
||||
- repository contract behavior across all supported SQL backends
|
||||
- JSON/Unicode round trip
|
||||
- rollback behavior
|
||||
- persistence close/cleanup
|
||||
|
||||
E2E tests may remain local-only if CI does not provide PostgreSQL/MySQL services.
|
||||
@@ -0,0 +1,401 @@
|
||||
# Storage Package 设计文档
|
||||
|
||||
## 背景
|
||||
|
||||
DeerFlow 当前有多类持久化职责分散在 app、gateway、runtime 和旧 persistence 模块中。这会带来几个问题:
|
||||
|
||||
- routers 和 runtime services 容易依赖具体 persistence 实现,而不是稳定契约。
|
||||
- user/auth、run metadata、thread metadata、feedback、run events、checkpointer setup 的初始化路径不统一。
|
||||
- memory、SQLite、PostgreSQL 相关路径中存在部分重复逻辑。
|
||||
- app 层代码和 storage 层代码耦合,导致增量迁移困难。
|
||||
- 增加或验证新的 SQL backend 时,需要改动 app/runtime,而不是只改 storage package。
|
||||
|
||||
引入 storage package 的目标,是把应用数据持久化抽象成 package 级能力,并提供明确契约、清晰边界和 SQL backend 兼容性。
|
||||
|
||||
## 目标
|
||||
|
||||
- 新增独立的 `packages/storage`,负责 durable application data。
|
||||
- 通过统一 persistence 构造流程支持 SQLite、PostgreSQL、MySQL。
|
||||
- 保持 LangGraph checkpointer 与同一个数据库 backend 兼容。
|
||||
- 将 repository contracts 作为 package 对外唯一数据访问边界。
|
||||
- app 层通过 `app.infra.storage` 适配 storage,而不是直接依赖 storage DB 实现类。
|
||||
- 支持 app/gateway 后续小步迁移,避免一次性大重构。
|
||||
|
||||
## 非目标
|
||||
|
||||
- 第一阶段不删除旧 persistence。
|
||||
- 不让 routers 直接依赖 storage package models。
|
||||
- 不让 app routers 管理 SQLAlchemy sessions。
|
||||
- cron persistence 不属于 storage package 基础迁移范围。
|
||||
- memory backend 不属于 durable storage package。若 app runtime 仍需要 memory 兼容,应放在 `packages/storage` 之外。
|
||||
|
||||
## Storage 设计理念
|
||||
|
||||
### Package 自己负责 Durable Storage
|
||||
|
||||
`packages/storage` 负责应用数据的 durable persistence,包括:
|
||||
|
||||
- storage 持久化配置
|
||||
- SQLAlchemy models
|
||||
- repository contracts 和 DTOs
|
||||
- SQL repository 实现
|
||||
- persistence factory functions
|
||||
- 面向现有 config 的兼容初始化入口
|
||||
|
||||
该 package 不应该 import `app.gateway`、routers、auth providers 或 runtime 中的 gateway 对象。
|
||||
|
||||
### SQL Backend 兼容
|
||||
|
||||
该 package 支持三种 SQL backend:
|
||||
|
||||
- SQLite:本地或单节点部署
|
||||
- PostgreSQL:生产多节点部署
|
||||
- MySQL:使用 MySQL 作为标准数据库的部署
|
||||
|
||||
backend 差异在 storage package 内部处理:
|
||||
|
||||
- SQLAlchemy async engine URL 构造
|
||||
- LangGraph checkpointer 连接串兼容
|
||||
- SQLite/PostgreSQL/MySQL 的 JSON metadata filter
|
||||
- 不同 SQL 方言在 locking、aggregation、JSON 类型语义上的差异
|
||||
|
||||
### 统一 Persistence Bundle
|
||||
|
||||
Storage 初始化返回 `AppPersistence` bundle:
|
||||
|
||||
```python
|
||||
@dataclass(slots=True)
|
||||
class AppPersistence:
|
||||
checkpointer: Checkpointer
|
||||
engine: AsyncEngine
|
||||
session_factory: async_sessionmaker[AsyncSession]
|
||||
setup: Callable[[], Awaitable[None]]
|
||||
aclose: Callable[[], Awaitable[None]]
|
||||
```
|
||||
|
||||
app runtime 只需要初始化一次 persistence,调用 `setup()`,然后注入:
|
||||
|
||||
- `checkpointer`
|
||||
- `session_factory`
|
||||
- repository adapters
|
||||
|
||||
这样 checkpointer 和应用数据可以对齐到同一个 backend,同时 routers 不需要理解数据库配置。
|
||||
|
||||
## Package 结构
|
||||
|
||||
```text
|
||||
backend/packages/storage/
|
||||
store/
|
||||
config/
|
||||
storage_config.py
|
||||
app_config.py
|
||||
persistence/
|
||||
factory.py
|
||||
types.py
|
||||
base_model.py
|
||||
json_compat.py
|
||||
drivers/
|
||||
sqlite.py
|
||||
postgres.py
|
||||
mysql.py
|
||||
repositories/
|
||||
contracts/
|
||||
user.py
|
||||
run.py
|
||||
thread_meta.py
|
||||
feedback.py
|
||||
run_event.py
|
||||
models/
|
||||
user.py
|
||||
run.py
|
||||
thread_meta.py
|
||||
feedback.py
|
||||
run_event.py
|
||||
db/
|
||||
user.py
|
||||
run.py
|
||||
thread_meta.py
|
||||
feedback.py
|
||||
run_event.py
|
||||
factory.py
|
||||
```
|
||||
|
||||
## Persistence 构造
|
||||
|
||||
storage 的主要入口:
|
||||
|
||||
```python
|
||||
from store.persistence import create_persistence_from_storage_config
|
||||
|
||||
persistence = await create_persistence_from_storage_config(storage_config)
|
||||
await persistence.setup()
|
||||
```
|
||||
|
||||
为了兼容现有 app database config,也提供:
|
||||
|
||||
```python
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
|
||||
persistence = await create_persistence_from_database_config(config.database)
|
||||
await persistence.setup()
|
||||
```
|
||||
|
||||
预期 app startup 流程:
|
||||
|
||||
```python
|
||||
persistence = await create_persistence_from_database_config(config.database)
|
||||
await persistence.setup()
|
||||
|
||||
app.state.persistence = persistence
|
||||
app.state.checkpointer = persistence.checkpointer
|
||||
app.state.session_factory = persistence.session_factory
|
||||
```
|
||||
|
||||
预期 app shutdown 流程:
|
||||
|
||||
```python
|
||||
await app.state.persistence.aclose()
|
||||
```
|
||||
|
||||
## Repository 契约设计
|
||||
|
||||
Repository contracts 是 storage package 对外公开的数据访问边界。它们位于 `store.repositories.contracts`,并通过 `store.repositories` re-export。
|
||||
|
||||
主要契约包括:
|
||||
|
||||
- `UserRepositoryProtocol`
|
||||
- `RunRepositoryProtocol`
|
||||
- `ThreadMetaRepositoryProtocol`
|
||||
- `FeedbackRepositoryProtocol`
|
||||
- `RunEventRepositoryProtocol`
|
||||
|
||||
每组契约包含:
|
||||
|
||||
- 输入 DTO,例如 `UserCreate`、`RunCreate`、`ThreadMetaCreate`
|
||||
- 输出 DTO,例如 `User`、`Run`、`ThreadMeta`
|
||||
- repository protocol methods
|
||||
- 必要的领域异常,例如 `InvalidMetadataFilterError`
|
||||
|
||||
Repository 通过 session 构造:
|
||||
|
||||
```python
|
||||
from store.repositories import build_run_repository
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_repository(session)
|
||||
run = await repo.get_run(run_id)
|
||||
```
|
||||
|
||||
这样可以让 transaction ownership 保持明确。storage package 不通过全局 singleton 隐式隐藏 commit 或 session 生命周期。
|
||||
|
||||
## App/Infra 调用契约
|
||||
|
||||
app 层不应该直接调用 `store.repositories.db.*`。预期的 app 边界是 `app.infra.storage`。
|
||||
|
||||
`app.infra.storage` 负责:
|
||||
|
||||
- 从 FastAPI runtime 初始化中接收 `session_factory`
|
||||
- 为 app-facing repository methods 管理 session 生命周期
|
||||
- 在必要时将 storage DTOs 转成 app/gateway DTOs
|
||||
- 迁移期间保留现有 app-facing 名称
|
||||
- 依赖 storage repository protocols,而不是具体 DB classes
|
||||
|
||||
预期 adapter 模式:
|
||||
|
||||
```python
|
||||
class StorageRunRepository(RunRepositoryProtocol):
|
||||
def __init__(self, session_factory):
|
||||
self._session_factory = session_factory
|
||||
|
||||
async def get_run(self, run_id: str):
|
||||
async with self._session_factory() as session:
|
||||
repo = build_run_repository(session)
|
||||
return await repo.get_run(run_id)
|
||||
```
|
||||
|
||||
为了兼容 gateway,app state 可以暂时保持现有名字,只替换内部实现:
|
||||
|
||||
```python
|
||||
app.state.run_store = StorageRunStore(run_repository)
|
||||
app.state.feedback_repo = StorageFeedbackStore(feedback_repository)
|
||||
app.state.thread_store = StorageThreadMetaStore(thread_meta_repository)
|
||||
app.state.run_event_store = StorageRunEventStore(run_event_repository)
|
||||
app.state.checkpointer = persistence.checkpointer
|
||||
app.state.session_factory = persistence.session_factory
|
||||
```
|
||||
|
||||
app-facing objects 可以在迁移期间保留旧方法名,但内部数据访问必须经过 storage contracts。
|
||||
|
||||
## 边界规则
|
||||
|
||||
### 允许调用的范围
|
||||
|
||||
storage package 调用方可以使用:
|
||||
|
||||
```python
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.persistence import create_persistence_from_storage_config
|
||||
from store.repositories import build_run_repository
|
||||
from store.repositories import build_user_repository
|
||||
from store.repositories import build_thread_meta_repository
|
||||
from store.repositories import build_feedback_repository
|
||||
from store.repositories import build_run_event_repository
|
||||
from store.repositories import RunRepositoryProtocol
|
||||
from store.repositories import UserRepositoryProtocol
|
||||
```
|
||||
|
||||
app 层应该使用:
|
||||
|
||||
```python
|
||||
from app.infra.storage import StorageRunRepository
|
||||
from app.infra.storage import StorageUserDataRepository
|
||||
from app.infra.storage import StorageThreadMetaRepository
|
||||
from app.infra.storage import StorageFeedbackRepository
|
||||
from app.infra.storage import StorageRunEventRepository
|
||||
```
|
||||
|
||||
### 禁止调用的范围
|
||||
|
||||
app/gateway/router/auth 代码不应该 import:
|
||||
|
||||
```python
|
||||
from store.repositories.db import DbRunRepository
|
||||
from store.repositories.models import Run
|
||||
from store.persistence.base_model import MappedBase
|
||||
```
|
||||
|
||||
routers 禁止:
|
||||
|
||||
- 创建 SQLAlchemy engines
|
||||
- 直接创建 SQLAlchemy sessions
|
||||
- 直接调用 storage DB repository classes
|
||||
- 直接 commit/rollback storage transactions,除非这是 infra adapter 明确管理的范围
|
||||
- 依赖 storage SQLAlchemy model classes
|
||||
|
||||
storage package 禁止 import:
|
||||
|
||||
```python
|
||||
import app.gateway
|
||||
import app.infra
|
||||
import deerflow.runtime
|
||||
```
|
||||
|
||||
依赖方向必须是:
|
||||
|
||||
```text
|
||||
app/gateway -> app.infra.storage -> packages/storage contracts/factories -> packages/storage db implementations
|
||||
```
|
||||
|
||||
禁止反向依赖。
|
||||
|
||||
## Checkpointer 兼容
|
||||
|
||||
storage persistence bundle 会同时初始化 LangGraph checkpointer 和应用数据持久化。
|
||||
|
||||
backend 说明:
|
||||
|
||||
- SQLite 使用 `langgraph-checkpoint-sqlite`。
|
||||
- PostgreSQL 使用 `langgraph-checkpoint-postgres`,需要字符串形式的 `postgresql://...` 连接串。
|
||||
- MySQL 使用 `langgraph-checkpoint-mysql`,需要字符串形式的 MySQL 连接串。
|
||||
|
||||
SQLAlchemy 可以使用 `postgresql+asyncpg://...` 或 `mysql+aiomysql://...` 这类 async driver URL,但 LangGraph checkpointer 构造函数需要普通字符串连接串。这个转换应该封装在 storage driver implementation 内部。
|
||||
|
||||
## JSON Metadata Filtering
|
||||
|
||||
Thread metadata search 通过 `store.persistence.json_compat` 支持跨方言 JSON filtering。
|
||||
|
||||
支持的 filter value 类型:
|
||||
|
||||
- `None`
|
||||
- `bool`
|
||||
- `int`
|
||||
- `float`
|
||||
- `str`
|
||||
|
||||
拒绝:
|
||||
|
||||
- unsafe keys
|
||||
- nested JSON path expressions
|
||||
- dict/list values
|
||||
- 超出 signed 64-bit 范围的整数
|
||||
|
||||
这样可以避免 SQL/JSON path injection,避免 compiled-cache 类型漂移,并保留类型语义,例如 `True != 1`,显式 JSON `null` 不等于 missing key。
|
||||
|
||||
## 分步实现方案
|
||||
|
||||
### 第 1 步:新增 Storage Package 基础
|
||||
|
||||
- 新增 `backend/packages/storage`。
|
||||
- 增加 storage config models。
|
||||
- 增加 `AppPersistence`。
|
||||
- 增加 SQLite/PostgreSQL/MySQL persistence drivers。
|
||||
- 增加 repository contracts、models、DB implementations 和 factory helpers。
|
||||
- 接入 package dependency。
|
||||
- 排除 cron persistence。
|
||||
|
||||
### 第 2 步:补齐 Storage Backend 兼容性
|
||||
|
||||
- 验证 SQLite setup 和 repository 行为。
|
||||
- 使用本地 E2E 验证 PostgreSQL 和 MySQL。
|
||||
- 修复 checkpointer 连接串兼容。
|
||||
- 修复 PostgreSQL locking 和 aggregation 差异。
|
||||
- 增加跨方言 JSON metadata filtering。
|
||||
|
||||
### 第 3 步:新增 App Infra Adapters
|
||||
|
||||
- 新增 `backend/app/infra/storage`。
|
||||
- 实现 app-facing repositories,由它们管理 session 生命周期。
|
||||
- 保持 storage contracts 作为唯一数据访问边界。
|
||||
- 为现有 app/gateway method shape 增加兼容 adapters。
|
||||
- 避免 `packages/storage` import app/gateway。
|
||||
|
||||
### 第 4 步:切换 FastAPI Runtime 注入
|
||||
|
||||
- 在 FastAPI startup/lifespan 中初始化 storage persistence。
|
||||
- 将 `persistence`、`checkpointer`、`session_factory` 注入 `app.state`。
|
||||
- 暂时保留现有对外 state 名称:
|
||||
- `run_store`
|
||||
- `feedback_repo`
|
||||
- `thread_store`
|
||||
- `run_event_store`
|
||||
- `checkpointer`
|
||||
- `session_factory`
|
||||
- 先切 user/auth provider 构造,再逐步迁移 run/thread/feedback/run_event。
|
||||
|
||||
### 第 5 步:Router 和 Auth 兼容
|
||||
|
||||
- 确保 routers 消费 app-facing adapters,而不是 storage DB classes。
|
||||
- 确保 auth providers 依赖 user repository contracts。
|
||||
- 保持 router response shapes 不变。
|
||||
- 增加 auth/admin/router regression tests。
|
||||
|
||||
### 第 6 步:清理旧 Persistence
|
||||
|
||||
- app/gateway 迁移完成后,再比较旧 persistence usage。
|
||||
- 所有 call sites 迁移完成后,再删除未使用的旧 repository implementations。
|
||||
- 只在必要时保留短期 compatibility shims。
|
||||
- 从 storage-owned durable persistence 中移除 memory backend 路径。
|
||||
|
||||
## 测试策略
|
||||
|
||||
单测应覆盖:
|
||||
|
||||
- config parsing
|
||||
- persistence setup
|
||||
- table creation
|
||||
- repository CRUD/query behavior
|
||||
- typed JSON metadata filtering
|
||||
- dialect SQL compilation
|
||||
- cron exclusion
|
||||
|
||||
E2E 应覆盖:
|
||||
|
||||
- SQLite persistence setup
|
||||
- PostgreSQL temporary database setup
|
||||
- MySQL temporary database setup
|
||||
- 所有支持 SQL backend 下的 repository contract 行为
|
||||
- JSON/Unicode round trip
|
||||
- rollback behavior
|
||||
- persistence close/cleanup
|
||||
|
||||
如果 CI 暂时没有 PostgreSQL/MySQL services,E2E 可以先作为 local-only 验证保留。
|
||||
+22
-27
@@ -104,46 +104,45 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
return "[Tool call was interrupted and did not return a result.]"
|
||||
|
||||
def _build_patched_messages(self, messages: list) -> list | None:
|
||||
"""Return messages with tool results grouped after their tool-call AIMessage.
|
||||
"""Return a new message list with patches inserted at the correct positions.
|
||||
|
||||
This normalizes model-bound causal order before provider serialization while
|
||||
preserving already-valid transcripts unchanged.
|
||||
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
|
||||
a synthetic ToolMessage is inserted immediately after that AIMessage.
|
||||
Returns None if no patches are needed.
|
||||
"""
|
||||
tool_messages_by_id: dict[str, ToolMessage] = {}
|
||||
# Collect IDs of all existing ToolMessages
|
||||
existing_tool_msg_ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
||||
existing_tool_msg_ids.add(msg.tool_call_id)
|
||||
|
||||
tool_call_ids: set[str] = set()
|
||||
# Check if any patching is needed
|
||||
needs_patch = False
|
||||
for msg in messages:
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if tc_id:
|
||||
tool_call_ids.add(tc_id)
|
||||
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||
needs_patch = True
|
||||
break
|
||||
if needs_patch:
|
||||
break
|
||||
|
||||
if not needs_patch:
|
||||
return None
|
||||
|
||||
# Build new list with patches inserted right after each dangling AIMessage
|
||||
patched: list = []
|
||||
consumed_tool_msg_ids: set[str] = set()
|
||||
patched_ids: set[str] = set()
|
||||
patch_count = 0
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
||||
continue
|
||||
|
||||
patched.append(msg)
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if not tc_id or tc_id in consumed_tool_msg_ids:
|
||||
continue
|
||||
|
||||
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
||||
if existing_tool_msg is not None:
|
||||
patched.append(existing_tool_msg)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
else:
|
||||
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||
patched.append(
|
||||
ToolMessage(
|
||||
content=self._synthetic_tool_message_content(tc),
|
||||
@@ -152,14 +151,10 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
status="error",
|
||||
)
|
||||
)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
patched_ids.add(tc_id)
|
||||
patch_count += 1
|
||||
|
||||
if patched == messages:
|
||||
return None
|
||||
|
||||
if patch_count:
|
||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||
return patched
|
||||
|
||||
@override
|
||||
|
||||
@@ -10,7 +10,6 @@ from typing import Any, Protocol, override, runtime_checkable
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import SummarizationMiddleware
|
||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||
from langchain_core.messages.utils import get_buffer_string
|
||||
from langgraph.config import get_config
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
@@ -176,84 +175,12 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
]
|
||||
}
|
||||
|
||||
@override
|
||||
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||
"""Generate summary without emitting streaming events to the client.
|
||||
|
||||
Suppresses callbacks to prevent the internal summarization LLM call from
|
||||
producing visible AI message chunks in the frontend's ``messages-tuple``
|
||||
stream (issue #2804).
|
||||
"""
|
||||
if not messages_to_summarize:
|
||||
return "No previous conversation history."
|
||||
|
||||
trimmed = self._trim_messages_for_summary(messages_to_summarize)
|
||||
if not trimmed:
|
||||
return "Previous conversation was too long to summarize."
|
||||
|
||||
formatted = get_buffer_string(trimmed)
|
||||
|
||||
try:
|
||||
response = self.model.with_config(callbacks=[]).invoke(
|
||||
self.summary_prompt.format(messages=formatted).rstrip(),
|
||||
config={
|
||||
"metadata": {"lc_source": "summarization"},
|
||||
"callbacks": [],
|
||||
},
|
||||
)
|
||||
return self._extract_summary_text(response)
|
||||
except Exception as e:
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
@override
|
||||
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||
"""Generate summary without emitting streaming events to the client.
|
||||
|
||||
Suppresses callbacks to prevent the internal summarization LLM call from
|
||||
producing visible AI message chunks in the frontend's ``messages-tuple``
|
||||
stream (issue #2804).
|
||||
"""
|
||||
if not messages_to_summarize:
|
||||
return "No previous conversation history."
|
||||
|
||||
trimmed = self._trim_messages_for_summary(messages_to_summarize)
|
||||
if not trimmed:
|
||||
return "Previous conversation was too long to summarize."
|
||||
|
||||
formatted = get_buffer_string(trimmed)
|
||||
|
||||
try:
|
||||
response = await self.model.with_config(callbacks=[]).ainvoke(
|
||||
self.summary_prompt.format(messages=formatted).rstrip(),
|
||||
config={
|
||||
"metadata": {"lc_source": "summarization"},
|
||||
"callbacks": [],
|
||||
},
|
||||
)
|
||||
return self._extract_summary_text(response)
|
||||
except Exception as e:
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
def _extract_summary_text(self, response: Any) -> str:
|
||||
# Prefer .text which normalizes list content blocks (e.g. [{"type": "text", "text": "..."}]).
|
||||
# Fall back to .content for non-LangChain responses.
|
||||
summary_text = getattr(response, "text", None)
|
||||
if summary_text is None:
|
||||
summary_text = getattr(response, "content", "")
|
||||
return summary_text.strip() if isinstance(summary_text, str) else str(summary_text).strip()
|
||||
|
||||
@override
|
||||
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
|
||||
"""Override the base implementation to let the human message with the special name 'summary'.
|
||||
And this message will be ignored to display in the frontend, but still can be used as context for the model.
|
||||
"""
|
||||
return [
|
||||
HumanMessage(
|
||||
content=f"Here is a summary of the conversation to date:\n\n{summary}",
|
||||
name="summary",
|
||||
additional_kwargs={"hide_from_ui": True},
|
||||
)
|
||||
]
|
||||
return [HumanMessage(content=f"Here is a summary of the conversation to date:\n\n{summary}", name="summary")]
|
||||
|
||||
def _preserve_dynamic_context_reminders(
|
||||
self,
|
||||
|
||||
@@ -7,21 +7,17 @@ reminder message so the model still knows about the outstanding todo list.
|
||||
|
||||
Additionally, this middleware prevents the agent from exiting the loop while
|
||||
there are still incomplete todo items. When the model produces a final response
|
||||
(no tool calls) but todos are not yet complete, the middleware queues a reminder
|
||||
for the next model request and jumps back to the model node to force continued
|
||||
engagement. The completion reminder is injected via ``wrap_model_call`` instead
|
||||
of being persisted into graph state as a normal user-visible message.
|
||||
(no tool calls) but todos are not yet complete, the middleware injects a reminder
|
||||
and jumps back to the model node to force continued engagement.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
|
||||
from langchain.agents.middleware.types import hook_config
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
@@ -59,51 +55,6 @@ def _format_todos(todos: list[Todo]) -> str:
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_completion_reminder(todos: list[Todo]) -> str:
|
||||
"""Format a completion reminder for incomplete todo items."""
|
||||
incomplete = [t for t in todos if t.get("status") != "completed"]
|
||||
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
||||
return (
|
||||
"<system_reminder>\n"
|
||||
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
||||
f"{incomplete_text}\n\n"
|
||||
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
||||
"as you finish them, and only respond when all items are done.\n"
|
||||
"</system_reminder>"
|
||||
)
|
||||
|
||||
|
||||
_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"}
|
||||
|
||||
|
||||
def _has_tool_call_intent_or_error(message: AIMessage) -> bool:
|
||||
"""Return True when an AIMessage is not a clean final answer.
|
||||
|
||||
Todo completion reminders should only fire when the model has produced a
|
||||
plain final response. Provider/tool parsing details have moved across
|
||||
LangChain versions and integrations, so keep all tool-intent/error signals
|
||||
behind this helper instead of checking one concrete field at the call site.
|
||||
"""
|
||||
if message.tool_calls:
|
||||
return True
|
||||
|
||||
if getattr(message, "invalid_tool_calls", None):
|
||||
return True
|
||||
|
||||
# Backward/provider compatibility: some integrations preserve raw or legacy
|
||||
# tool-call intent in additional_kwargs even when structured tool_calls is
|
||||
# empty. If this helper changes, update the matching sentinel test
|
||||
# `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`;
|
||||
# if that test fails after a LangChain upgrade, review this helper so new
|
||||
# tool-call/error fields are not silently treated as clean final answers.
|
||||
additional_kwargs = getattr(message, "additional_kwargs", {}) or {}
|
||||
if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"):
|
||||
return True
|
||||
|
||||
response_metadata = getattr(message, "response_metadata", {}) or {}
|
||||
return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS
|
||||
|
||||
|
||||
class TodoMiddleware(TodoListMiddleware):
|
||||
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
||||
|
||||
@@ -138,7 +89,6 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
formatted = _format_todos(todos)
|
||||
reminder = HumanMessage(
|
||||
name="todo_reminder",
|
||||
additional_kwargs={"hide_from_ui": True},
|
||||
content=(
|
||||
"<system_reminder>\n"
|
||||
"Your todo list from earlier is no longer visible in the current context window, "
|
||||
@@ -163,100 +113,6 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
# Maximum number of completion reminders before allowing the agent to exit.
|
||||
# This prevents infinite loops when the agent cannot make further progress.
|
||||
_MAX_COMPLETION_REMINDERS = 2
|
||||
# Hard cap for per-run reminder bookkeeping in long-lived middleware instances.
|
||||
_MAX_COMPLETION_REMINDER_KEYS = 4096
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._lock = threading.Lock()
|
||||
self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {}
|
||||
self._completion_reminder_counts: dict[tuple[str, str], int] = {}
|
||||
self._completion_reminder_touch_order: dict[tuple[str, str], int] = {}
|
||||
self._completion_reminder_next_order = 0
|
||||
|
||||
@staticmethod
|
||||
def _get_thread_id(runtime: Runtime) -> str:
|
||||
context = getattr(runtime, "context", None)
|
||||
thread_id = context.get("thread_id") if context else None
|
||||
return str(thread_id) if thread_id else "default"
|
||||
|
||||
@staticmethod
|
||||
def _get_run_id(runtime: Runtime) -> str:
|
||||
context = getattr(runtime, "context", None)
|
||||
run_id = context.get("run_id") if context else None
|
||||
return str(run_id) if run_id else "default"
|
||||
|
||||
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
|
||||
return self._get_thread_id(runtime), self._get_run_id(runtime)
|
||||
|
||||
def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
||||
self._completion_reminder_next_order += 1
|
||||
self._completion_reminder_touch_order[key] = self._completion_reminder_next_order
|
||||
|
||||
def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]:
|
||||
keys = set(self._pending_completion_reminders)
|
||||
keys.update(self._completion_reminder_counts)
|
||||
keys.update(self._completion_reminder_touch_order)
|
||||
return keys
|
||||
|
||||
def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
||||
self._pending_completion_reminders.pop(key, None)
|
||||
self._completion_reminder_counts.pop(key, None)
|
||||
self._completion_reminder_touch_order.pop(key, None)
|
||||
|
||||
def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None:
|
||||
keys = self._completion_reminder_keys_locked()
|
||||
overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS
|
||||
if overflow <= 0:
|
||||
return
|
||||
|
||||
candidates = [key for key in keys if key != protected_key]
|
||||
candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0))
|
||||
for key in candidates[:overflow]:
|
||||
self._drop_completion_reminder_key_locked(key)
|
||||
|
||||
def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
self._pending_completion_reminders.setdefault(key, []).append(reminder)
|
||||
self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1
|
||||
self._touch_completion_reminder_key_locked(key)
|
||||
self._prune_completion_reminder_state_locked(protected_key=key)
|
||||
|
||||
def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
return self._completion_reminder_counts.get(key, 0)
|
||||
|
||||
def _drain_completion_reminders(self, runtime: Runtime) -> list[str]:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
reminders = self._pending_completion_reminders.pop(key, [])
|
||||
if reminders or key in self._completion_reminder_counts:
|
||||
self._touch_completion_reminder_key_locked(key)
|
||||
return reminders
|
||||
|
||||
def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None:
|
||||
thread_id, current_run_id = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
for key in self._completion_reminder_keys_locked():
|
||||
if key[0] == thread_id and key[1] != current_run_id:
|
||||
self._drop_completion_reminder_key_locked(key)
|
||||
|
||||
def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
self._drop_completion_reminder_key_locked(key)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_other_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_other_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@hook_config(can_jump_to=["model"])
|
||||
@override
|
||||
@@ -281,12 +137,10 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
if base_result is not None:
|
||||
return base_result
|
||||
|
||||
# 2. Only intervene when the agent wants to exit cleanly. Tool-call
|
||||
# intent or tool-call parse errors should be handled by the tool path
|
||||
# instead of being masked by todo reminders.
|
||||
# 2. Only intervene when the agent wants to exit (no tool calls).
|
||||
messages = state.get("messages") or []
|
||||
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
|
||||
if not last_ai or _has_tool_call_intent_or_error(last_ai):
|
||||
if not last_ai or last_ai.tool_calls:
|
||||
return None
|
||||
|
||||
# 3. Allow exit when all todos are completed or there are no todos.
|
||||
@@ -295,14 +149,24 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
return None
|
||||
|
||||
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
|
||||
if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS:
|
||||
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
|
||||
return None
|
||||
|
||||
# 5. Queue a reminder for the next model request and jump back. We must
|
||||
# not persist this control prompt as a normal HumanMessage, otherwise it
|
||||
# can leak into user-visible message streams and saved transcripts.
|
||||
self._queue_completion_reminder(runtime, _format_completion_reminder(todos))
|
||||
return {"jump_to": "model"}
|
||||
# 5. Inject a reminder and force the agent back to the model.
|
||||
incomplete = [t for t in todos if t.get("status") != "completed"]
|
||||
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
||||
reminder = HumanMessage(
|
||||
name="todo_completion_reminder",
|
||||
content=(
|
||||
"<system_reminder>\n"
|
||||
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
||||
f"{incomplete_text}\n\n"
|
||||
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
||||
"as you finish them, and only respond when all items are done.\n"
|
||||
"</system_reminder>"
|
||||
),
|
||||
)
|
||||
return {"jump_to": "model", "messages": [reminder]}
|
||||
|
||||
@override
|
||||
@hook_config(can_jump_to=["model"])
|
||||
@@ -313,47 +177,3 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async version of after_model."""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@staticmethod
|
||||
def _format_pending_completion_reminders(reminders: list[str]) -> str:
|
||||
return "\n\n".join(dict.fromkeys(reminders))
|
||||
|
||||
def _augment_request(self, request: ModelRequest) -> ModelRequest:
|
||||
reminders = self._drain_completion_reminders(request.runtime)
|
||||
if not reminders:
|
||||
return request
|
||||
new_messages = [
|
||||
*request.messages,
|
||||
HumanMessage(
|
||||
content=self._format_pending_completion_reminders(reminders),
|
||||
name="todo_completion_reminder",
|
||||
additional_kwargs={"hide_from_ui": True},
|
||||
),
|
||||
]
|
||||
return request.override(messages=new_messages)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(self._augment_request(request))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
return await handler(self._augment_request(request))
|
||||
|
||||
@override
|
||||
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_current_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_current_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@@ -35,7 +35,7 @@ def _token_usage_cache_enabled(app_config: "AppConfig | None") -> bool:
|
||||
if app_config is None:
|
||||
try:
|
||||
app_config = get_app_config()
|
||||
except FileNotFoundError:
|
||||
except (FileNotFoundError, ValueError):
|
||||
return False
|
||||
return bool(getattr(getattr(app_config, "token_usage", None), "enabled", False))
|
||||
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
[project]
|
||||
name = "deerflow-storage"
|
||||
version = "0.1.0"
|
||||
description = "DeerFlow storage framework"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"dotenv>=0.9.9",
|
||||
"pydantic>=2.12.5",
|
||||
"pyyaml>=6.0.3",
|
||||
"sqlalchemy[asyncio]>=2.0,<3.0",
|
||||
"alembic>=1.13",
|
||||
"langgraph>=1.1.9",
|
||||
]
|
||||
[project.optional-dependencies]
|
||||
postgres = [
|
||||
"asyncpg>=0.29",
|
||||
"langgraph-checkpoint-postgres>=3.0.5",
|
||||
"psycopg[binary]>=3.3.3",
|
||||
"psycopg-pool>=3.3.0",
|
||||
]
|
||||
mysql = [
|
||||
"aiomysql>=0.2",
|
||||
"langgraph-checkpoint-mysql>=3.0.0",
|
||||
]
|
||||
sqlite = [
|
||||
"aiosqlite>=0.22.1",
|
||||
"langgraph-checkpoint-sqlite>=3.0.3"
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["store"]
|
||||
@@ -0,0 +1,5 @@
|
||||
from .enums import DataBaseType
|
||||
|
||||
__all__ = [
|
||||
"DataBaseType",
|
||||
]
|
||||
@@ -0,0 +1,41 @@
|
||||
from enum import Enum
|
||||
from enum import IntEnum as SourceIntEnum
|
||||
from enum import StrEnum as SourceStrEnum
|
||||
from typing import Any, TypeVar
|
||||
|
||||
T = TypeVar("T", bound=Enum)
|
||||
|
||||
|
||||
class _EnumBase:
|
||||
"""Base enum class with common utility methods."""
|
||||
|
||||
@classmethod
|
||||
def get_member_keys(cls) -> list[str]:
|
||||
"""Return a list of enum member names."""
|
||||
return list(cls.__members__.keys())
|
||||
|
||||
@classmethod
|
||||
def get_member_values(cls) -> list:
|
||||
"""Return a list of enum member values."""
|
||||
return [item.value for item in cls.__members__.values()]
|
||||
|
||||
@classmethod
|
||||
def get_member_dict(cls) -> dict[str, Any]:
|
||||
"""Return a dict mapping member names to values."""
|
||||
return {name: item.value for name, item in cls.__members__.items()}
|
||||
|
||||
|
||||
class IntEnum(_EnumBase, SourceIntEnum):
|
||||
"""Integer enum base class."""
|
||||
|
||||
|
||||
class StrEnum(_EnumBase, SourceStrEnum):
|
||||
"""String enum base class."""
|
||||
|
||||
|
||||
class DataBaseType(StrEnum):
|
||||
"""Database type."""
|
||||
|
||||
sqlite = "sqlite"
|
||||
mysql = "mysql"
|
||||
postgresql = "postgresql"
|
||||
@@ -0,0 +1,286 @@
|
||||
import logging
|
||||
import os
|
||||
from contextvars import ContextVar
|
||||
from pathlib import Path
|
||||
from typing import Any, Self
|
||||
|
||||
import yaml
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from store.config.storage_config import StorageConfig
|
||||
|
||||
load_dotenv()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _default_config_candidates() -> tuple[Path, ...]:
|
||||
"""Return deterministic config.yaml locations without relying on cwd."""
|
||||
backend_dir = Path(__file__).resolve().parents[4]
|
||||
repo_root = backend_dir.parent
|
||||
cwd = Path.cwd().resolve()
|
||||
candidates = (
|
||||
cwd / "config.yaml",
|
||||
backend_dir / "config.yaml",
|
||||
repo_root / "config.yaml",
|
||||
)
|
||||
return tuple(dict.fromkeys(candidates))
|
||||
|
||||
|
||||
def _storage_from_database_config(config_data: dict[str, Any]) -> None:
|
||||
"""Keep the existing public `database:` config compatible with storage."""
|
||||
if "storage" in config_data:
|
||||
return
|
||||
|
||||
database = config_data.get("database")
|
||||
if not isinstance(database, dict):
|
||||
return
|
||||
|
||||
backend = database.get("backend")
|
||||
if backend == "memory":
|
||||
raise ValueError("database.backend='memory' is not supported by storage; handle memory mode before loading storage config")
|
||||
|
||||
storage: dict[str, Any] = {
|
||||
"driver": "postgres" if backend == "postgres" else backend,
|
||||
"sqlite_dir": database.get("sqlite_dir", ".deer-flow/data"),
|
||||
"echo_sql": database.get("echo_sql", False),
|
||||
"pool_size": database.get("pool_size", 5),
|
||||
}
|
||||
|
||||
postgres_url = database.get("postgres_url")
|
||||
if backend == "postgres" and isinstance(postgres_url, str) and postgres_url:
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
parsed = make_url(postgres_url)
|
||||
storage["database_url"] = postgres_url
|
||||
storage.update(
|
||||
{
|
||||
"username": parsed.username or "",
|
||||
"password": parsed.password or "",
|
||||
"host": parsed.host or "localhost",
|
||||
"port": parsed.port or 5432,
|
||||
"db_name": parsed.database or "deerflow",
|
||||
}
|
||||
)
|
||||
|
||||
config_data["storage"] = storage
|
||||
|
||||
|
||||
class AppConfig(BaseModel):
|
||||
"""DeerFlow application configuration."""
|
||||
|
||||
timezone: str = Field(default="UTC", description="Timezone for scheduling and timestamps (e.g. 'UTC', 'America/New_York')")
|
||||
log_level: str = Field(default="info", description="Logging level for deerflow modules (debug/info/warning/error)")
|
||||
storage: StorageConfig = Field(default=StorageConfig())
|
||||
model_config = ConfigDict(extra="allow", frozen=False)
|
||||
|
||||
@classmethod
|
||||
def resolve_config_path(cls, config_path: str | None = None) -> Path:
|
||||
"""Resolve the config file path.
|
||||
|
||||
Priority:
|
||||
1. If provided `config_path` argument, use it.
|
||||
2. If provided `DEER_FLOW_CONFIG_PATH` environment variable, use it.
|
||||
3. Otherwise, search deterministic backend/repository-root defaults from `_default_config_candidates()`.
|
||||
"""
|
||||
if config_path:
|
||||
path = Path(config_path)
|
||||
if not Path.exists(path):
|
||||
raise FileNotFoundError(f"Config file specified by param `config_path` not found at {path}")
|
||||
return path
|
||||
elif os.getenv("DEER_FLOW_CONFIG_PATH"):
|
||||
path = Path(os.getenv("DEER_FLOW_CONFIG_PATH"))
|
||||
if not Path.exists(path):
|
||||
raise FileNotFoundError(f"Config file specified by environment variable `DEER_FLOW_CONFIG_PATH` not found at {path}")
|
||||
return path
|
||||
else:
|
||||
for path in _default_config_candidates():
|
||||
if path.exists():
|
||||
return path
|
||||
raise FileNotFoundError("`config.yaml` file not found at the default backend or repository root locations")
|
||||
|
||||
@classmethod
|
||||
def from_file(cls, config_path: str | None = None) -> Self:
|
||||
"""Load and validate config from YAML. See `resolve_config_path` for path resolution."""
|
||||
resolved_path = cls.resolve_config_path(config_path)
|
||||
with open(resolved_path, encoding="utf-8") as f:
|
||||
config_data = yaml.safe_load(f) or {}
|
||||
|
||||
cls._check_config_version(config_data, resolved_path)
|
||||
|
||||
config_data = cls.resolve_env_variables(config_data)
|
||||
_storage_from_database_config(config_data)
|
||||
|
||||
if os.getenv("TIMEZONE"):
|
||||
config_data["timezone"] = os.getenv("TIMEZONE")
|
||||
|
||||
result = cls.model_validate(config_data)
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _check_config_version(cls, config_data: dict, config_path: Path) -> None:
|
||||
"""Check if the user's config.yaml is outdated compared to config.example.yaml.
|
||||
|
||||
Emits a warning if the user's config_version is lower than the example's.
|
||||
Missing config_version is treated as version 0 (pre-versioning).
|
||||
"""
|
||||
try:
|
||||
user_version = int(config_data.get("config_version", 0))
|
||||
except (TypeError, ValueError):
|
||||
user_version = 0
|
||||
|
||||
# Find config.example.yaml by searching config.yaml's directory and its parents
|
||||
example_path = None
|
||||
search_dir = config_path.parent
|
||||
for _ in range(5): # search up to 5 levels
|
||||
candidate = search_dir / "config.example.yaml"
|
||||
if candidate.exists():
|
||||
example_path = candidate
|
||||
break
|
||||
parent = search_dir.parent
|
||||
if parent == search_dir:
|
||||
break
|
||||
search_dir = parent
|
||||
if example_path is None:
|
||||
return
|
||||
|
||||
try:
|
||||
with open(example_path, encoding="utf-8") as f:
|
||||
example_data = yaml.safe_load(f)
|
||||
raw = example_data.get("config_version", 0) if example_data else 0
|
||||
try:
|
||||
example_version = int(raw)
|
||||
except (TypeError, ValueError):
|
||||
example_version = 0
|
||||
except Exception:
|
||||
return
|
||||
|
||||
if user_version < example_version:
|
||||
logger.warning(
|
||||
"Your config.yaml (version %d) is outdated — the latest version is %d. Run `make config-upgrade` to merge new fields into your config.",
|
||||
user_version,
|
||||
example_version,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def resolve_env_variables(cls, config: Any) -> Any:
|
||||
"""Recursively replace $VAR strings with their environment variable values (e.g. $OPENAI_API_KEY)."""
|
||||
if isinstance(config, str):
|
||||
if config.startswith("$"):
|
||||
env_value = os.getenv(config[1:])
|
||||
if env_value is None:
|
||||
raise ValueError(f"Environment variable {config[1:]} not found for config value {config}")
|
||||
return env_value
|
||||
return config
|
||||
elif isinstance(config, dict):
|
||||
return {k: cls.resolve_env_variables(v) for k, v in config.items()}
|
||||
elif isinstance(config, list):
|
||||
return [cls.resolve_env_variables(item) for item in config]
|
||||
return config
|
||||
|
||||
|
||||
_app_config: AppConfig | None = None
|
||||
_app_config_path: Path | None = None
|
||||
_app_config_mtime: float | None = None
|
||||
_app_config_is_custom = False
|
||||
_current_app_config: ContextVar[AppConfig | None] = ContextVar("deerflow_current_app_config", default=None)
|
||||
_current_app_config_stack: ContextVar[tuple[AppConfig | None, ...]] = ContextVar("deerflow_current_app_config_stack", default=())
|
||||
|
||||
|
||||
def _get_config_mtime(config_path: Path) -> float | None:
|
||||
"""Get the modification time of a config file if it exists."""
|
||||
try:
|
||||
return config_path.stat().st_mtime
|
||||
except OSError:
|
||||
return None
|
||||
|
||||
|
||||
def _load_and_cache_app_config(config_path: str | None = None) -> AppConfig:
|
||||
"""Load config from disk and refresh cache metadata."""
|
||||
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
|
||||
|
||||
resolved_path = AppConfig.resolve_config_path(config_path)
|
||||
_app_config = AppConfig.from_file(str(resolved_path))
|
||||
_app_config_path = resolved_path
|
||||
_app_config_mtime = _get_config_mtime(resolved_path)
|
||||
_app_config_is_custom = False
|
||||
return _app_config
|
||||
|
||||
|
||||
def get_app_config() -> AppConfig:
|
||||
"""Get the DeerFlow config instance.
|
||||
|
||||
Returns a cached singleton instance and automatically reloads it when the
|
||||
underlying config file path or modification time changes. Use
|
||||
`reload_app_config()` to force a reload, or `reset_app_config()` to clear
|
||||
the cache.
|
||||
"""
|
||||
global _app_config, _app_config_path, _app_config_mtime
|
||||
|
||||
runtime_override = _current_app_config.get()
|
||||
if runtime_override is not None:
|
||||
return runtime_override
|
||||
|
||||
if _app_config is not None and _app_config_is_custom:
|
||||
return _app_config
|
||||
|
||||
resolved_path = AppConfig.resolve_config_path()
|
||||
current_mtime = _get_config_mtime(resolved_path)
|
||||
|
||||
should_reload = _app_config is None or _app_config_path != resolved_path or _app_config_mtime != current_mtime
|
||||
if should_reload:
|
||||
if _app_config_path == resolved_path and _app_config_mtime is not None and current_mtime is not None and _app_config_mtime != current_mtime:
|
||||
logger.info(
|
||||
"Config file has been modified (mtime: %s -> %s), reloading AppConfig",
|
||||
_app_config_mtime,
|
||||
current_mtime,
|
||||
)
|
||||
_load_and_cache_app_config(str(resolved_path))
|
||||
return _app_config
|
||||
|
||||
|
||||
def reload_app_config(config_path: str | None = None) -> AppConfig:
|
||||
"""Force reload from file and update the cache."""
|
||||
return _load_and_cache_app_config(config_path)
|
||||
|
||||
|
||||
def reset_app_config() -> None:
|
||||
"""Clear the cache so the next `get_app_config()` reloads from file."""
|
||||
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
|
||||
_app_config = None
|
||||
_app_config_path = None
|
||||
_app_config_mtime = None
|
||||
_app_config_is_custom = False
|
||||
|
||||
|
||||
def set_app_config(config: AppConfig) -> None:
|
||||
"""Inject a config instance directly, bypassing file loading (for testing)."""
|
||||
global _app_config, _app_config_path, _app_config_mtime, _app_config_is_custom
|
||||
_app_config = config
|
||||
_app_config_path = None
|
||||
_app_config_mtime = None
|
||||
_app_config_is_custom = True
|
||||
|
||||
|
||||
def peek_current_app_config() -> AppConfig | None:
|
||||
"""Return the runtime-scoped AppConfig override, if one is active."""
|
||||
return _current_app_config.get()
|
||||
|
||||
|
||||
def push_current_app_config(config: AppConfig) -> None:
|
||||
"""Push a runtime-scoped AppConfig override for the current execution context."""
|
||||
stack = _current_app_config_stack.get()
|
||||
_current_app_config_stack.set(stack + (_current_app_config.get(),))
|
||||
_current_app_config.set(config)
|
||||
|
||||
|
||||
def pop_current_app_config() -> None:
|
||||
"""Pop the latest runtime-scoped AppConfig override for the current execution context."""
|
||||
stack = _current_app_config_stack.get()
|
||||
if not stack:
|
||||
_current_app_config.set(None)
|
||||
return
|
||||
previous = stack[-1]
|
||||
_current_app_config_stack.set(stack[:-1])
|
||||
_current_app_config.set(previous)
|
||||
@@ -0,0 +1,69 @@
|
||||
"""Unified storage backend configuration for checkpointer and application data.
|
||||
|
||||
SQLite: checkpointer → {sqlite_dir}/checkpoints.db, app → {sqlite_dir}/deerflow.db
|
||||
(separate files to avoid write-lock contention)
|
||||
Postgres: shared URL, independent connection pools per layer.
|
||||
|
||||
Sensitive values use $VAR syntax resolved by AppConfig.resolve_env_variables()
|
||||
before this config is instantiated.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
def _strip_legacy_state_prefix(path: str) -> str:
|
||||
"""Keep old .deer-flow/* config values compatible with Paths.base_dir."""
|
||||
prefix = ".deer-flow/"
|
||||
if path == ".deer-flow":
|
||||
return "."
|
||||
if path.startswith(prefix):
|
||||
return path[len(prefix) :]
|
||||
return path
|
||||
|
||||
|
||||
class StorageConfig(BaseModel):
|
||||
driver: Literal["mysql", "sqlite", "postgres", "postgresql"] = Field(
|
||||
default="sqlite",
|
||||
description="Storage driver for both checkpointer and application data. 'sqlite' for single-node deployment (default),'postgres' for production multi-node deployment, 'mysql' for MySQL databases.",
|
||||
)
|
||||
sqlite_dir: str = Field(
|
||||
default=".deer-flow/data",
|
||||
description="Directory for SQLite .db files (sqlite driver only).",
|
||||
)
|
||||
username: str = Field(default="", description="db username ")
|
||||
password: str = Field(default="", description="db password. Use $VAR syntax in config.yaml to read from .env.")
|
||||
host: str = Field(default="localhost", description="db host.")
|
||||
port: int = Field(default=5432, description="db port.")
|
||||
db_name: str = Field(default="deerflow", description="db database name.")
|
||||
database_url: str = Field(default="", description="Complete SQLAlchemy database URL. Takes precedence for non-SQLite drivers.")
|
||||
sqlite_db_path: str = Field(default=".deer-flow/data", description="Directory for SQLite .db files (sqlite driver only).")
|
||||
echo_sql: bool = Field(default=False, description="Log all SQL statements (debug only).")
|
||||
pool_size: int = Field(default=5, description="Connection pool size per layer.")
|
||||
|
||||
# -- Derived helpers (not user-configured) --
|
||||
|
||||
@property
|
||||
def _resolved_sqlite_dir(self) -> str:
|
||||
"""Resolve sqlite_dir to an absolute path under DeerFlow's base dir."""
|
||||
from pathlib import Path
|
||||
|
||||
path = Path(self.sqlite_dir)
|
||||
if path.is_absolute():
|
||||
return str(path.resolve())
|
||||
|
||||
try:
|
||||
from deerflow.config.paths import resolve_path
|
||||
|
||||
return str(resolve_path(_strip_legacy_state_prefix(self.sqlite_dir)))
|
||||
except ImportError:
|
||||
return str(path.resolve())
|
||||
|
||||
@property
|
||||
def sqlite_storage_path(self) -> str:
|
||||
"""SQLite file path for storage-owned app data and checkpointer."""
|
||||
return os.path.join(self._resolved_sqlite_dir, "deerflow.db")
|
||||
@@ -0,0 +1,32 @@
|
||||
from store.persistence.base_model import (
|
||||
Base,
|
||||
DataClassBase,
|
||||
DateTimeMixin,
|
||||
MappedBase,
|
||||
TimeZone,
|
||||
UniversalText,
|
||||
id_key,
|
||||
)
|
||||
|
||||
from .factory import (
|
||||
create_persistence,
|
||||
create_persistence_from_database_config,
|
||||
create_persistence_from_storage_config,
|
||||
storage_config_from_database_config,
|
||||
)
|
||||
from .types import AppPersistence
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"DataClassBase",
|
||||
"DateTimeMixin",
|
||||
"MappedBase",
|
||||
"TimeZone",
|
||||
"UniversalText",
|
||||
"id_key",
|
||||
"create_persistence",
|
||||
"create_persistence_from_database_config",
|
||||
"create_persistence_from_storage_config",
|
||||
"storage_config_from_database_config",
|
||||
"AppPersistence",
|
||||
]
|
||||
@@ -0,0 +1,111 @@
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
|
||||
from sqlalchemy import BigInteger, DateTime, Integer, Text, TypeDecorator
|
||||
from sqlalchemy.dialects.mysql import LONGTEXT
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, MappedAsDataclass, declared_attr, mapped_column
|
||||
|
||||
from store.utils import get_timezone
|
||||
|
||||
|
||||
def current_time() -> datetime:
|
||||
return get_timezone().now()
|
||||
|
||||
|
||||
id_key = Annotated[
|
||||
int,
|
||||
mapped_column(
|
||||
BigInteger().with_variant(Integer, "sqlite"),
|
||||
primary_key=True,
|
||||
unique=True,
|
||||
index=True,
|
||||
autoincrement=True,
|
||||
sort_order=-999,
|
||||
comment="Primary key ID",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class UniversalText(TypeDecorator[str]):
|
||||
"""Cross-dialect long text type (LONGTEXT on MySQL, Text on PostgreSQL)."""
|
||||
|
||||
impl = Text
|
||||
cache_ok = True
|
||||
|
||||
def load_dialect_impl(self, dialect): # noqa: ANN001
|
||||
if dialect.name == "mysql":
|
||||
return dialect.type_descriptor(LONGTEXT())
|
||||
return dialect.type_descriptor(Text())
|
||||
|
||||
def process_bind_param(self, value: str | None, dialect) -> str | None: # noqa: ANN001
|
||||
return value
|
||||
|
||||
def process_result_value(self, value: str | None, dialect) -> str | None: # noqa: ANN001
|
||||
return value
|
||||
|
||||
|
||||
class TimeZone(TypeDecorator[datetime]):
|
||||
"""Timezone-aware datetime type compatible with PostgreSQL and MySQL."""
|
||||
|
||||
impl = DateTime(timezone=True)
|
||||
cache_ok = True
|
||||
|
||||
@property
|
||||
def python_type(self) -> type[datetime]:
|
||||
return datetime
|
||||
|
||||
def process_bind_param(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
||||
timezone = get_timezone()
|
||||
if value is not None and value.utcoffset() != timezone.now().utcoffset():
|
||||
value = timezone.from_datetime(value)
|
||||
return value
|
||||
|
||||
def process_result_value(self, value: datetime | None, dialect) -> datetime | None: # noqa: ANN001
|
||||
timezone = get_timezone()
|
||||
if value is not None and value.tzinfo is None:
|
||||
value = value.replace(tzinfo=timezone.tz_info)
|
||||
return value
|
||||
|
||||
|
||||
class DateTimeMixin(MappedAsDataclass):
|
||||
"""Mixin that adds created_time / updated_time columns."""
|
||||
|
||||
created_time: Mapped[datetime] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
updated_time: Mapped[datetime | None] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
onupdate=current_time,
|
||||
sort_order=999,
|
||||
comment="Updated at",
|
||||
)
|
||||
|
||||
|
||||
class MappedBase(AsyncAttrs, DeclarativeBase):
|
||||
"""Async-capable declarative base for all ORM models."""
|
||||
|
||||
@declared_attr.directive
|
||||
def __tablename__(self) -> str:
|
||||
return self.__name__.lower()
|
||||
|
||||
@declared_attr.directive
|
||||
def __table_args__(self) -> dict:
|
||||
return {"comment": self.__doc__ or ""}
|
||||
|
||||
|
||||
class DataClassBase(MappedAsDataclass, MappedBase):
|
||||
"""Declarative base with native dataclass integration."""
|
||||
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
class Base(DataClassBase, DateTimeMixin):
|
||||
"""Declarative dataclass base with created_time / updated_time columns."""
|
||||
|
||||
__abstract__ = True
|
||||
@@ -0,0 +1,9 @@
|
||||
from .mysql import build_mysql_persistence
|
||||
from .postgres import build_postgres_persistence
|
||||
from .sqlite import build_sqlite_persistence
|
||||
|
||||
__all__ = [
|
||||
"build_postgres_persistence",
|
||||
"build_mysql_persistence",
|
||||
"build_sqlite_persistence",
|
||||
]
|
||||
@@ -0,0 +1,76 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.engine import make_url
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def _validate_mysql_driver(db_url: URL) -> str:
|
||||
url = make_url(db_url)
|
||||
driver = url.get_driver_name()
|
||||
|
||||
if driver not in {"aiomysql", "asyncmy"}:
|
||||
raise ValueError(f"MySQL persistence requires async SQLAlchemy driver (aiomysql/asyncmy), got: {driver!r}")
|
||||
return driver
|
||||
|
||||
|
||||
def _checkpoint_conn_string(db_url: URL) -> str:
|
||||
return db_url.render_as_string(hide_password=False)
|
||||
|
||||
|
||||
async def build_mysql_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||
_validate_mysql_driver(db_url)
|
||||
|
||||
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=echo,
|
||||
future=True,
|
||||
pool_pre_ping=True,
|
||||
pool_size=pool_size,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AIOMySQLSaver.from_conn_string(_checkpoint_conn_string(db_url))
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
# 1. LangGraph checkpoint tables / migrations
|
||||
await checkpointer.setup()
|
||||
|
||||
# 2. ORM business tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
|
||||
async def _close_saver() -> None:
|
||||
await saver_cm.__aexit__(None, None, None)
|
||||
|
||||
async def aclose() -> None:
|
||||
await close_in_order(
|
||||
engine.dispose,
|
||||
_close_saver,
|
||||
)
|
||||
|
||||
return AppPersistence(
|
||||
checkpointer=checkpointer,
|
||||
engine=engine,
|
||||
session_factory=session_factory,
|
||||
setup=setup,
|
||||
aclose=aclose,
|
||||
)
|
||||
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def _checkpoint_conn_string(db_url: URL) -> str:
|
||||
return db_url.set(drivername="postgresql").render_as_string(hide_password=False)
|
||||
|
||||
|
||||
async def build_postgres_persistence(db_url: URL, *, echo: bool = False, pool_size: int = 5) -> AppPersistence:
|
||||
from langgraph.checkpoint.postgres.aio import AsyncPostgresSaver
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=echo,
|
||||
future=True,
|
||||
pool_pre_ping=True,
|
||||
pool_size=pool_size,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AsyncPostgresSaver.from_conn_string(_checkpoint_conn_string(db_url))
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
# 1. LangGraph checkpoint tables / migrations
|
||||
await checkpointer.setup()
|
||||
|
||||
# 2. ORM business tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
|
||||
async def _close_saver() -> None:
|
||||
await saver_cm.__aexit__(None, None, None)
|
||||
|
||||
async def aclose() -> None:
|
||||
await close_in_order(
|
||||
engine.dispose,
|
||||
_close_saver,
|
||||
)
|
||||
|
||||
return AppPersistence(
|
||||
checkpointer=checkpointer,
|
||||
engine=engine,
|
||||
session_factory=session_factory,
|
||||
setup=setup,
|
||||
aclose=aclose,
|
||||
)
|
||||
@@ -0,0 +1,68 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from sqlalchemy import URL, event
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
|
||||
from store.persistence import MappedBase
|
||||
from store.persistence.shared import close_in_order
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
async def build_sqlite_persistence(db_url: URL, *, echo: bool = False) -> AppPersistence:
|
||||
from langgraph.checkpoint.sqlite.aio import AsyncSqliteSaver
|
||||
|
||||
import store.repositories.models # noqa: F401
|
||||
|
||||
engine = create_async_engine(
|
||||
db_url,
|
||||
echo=echo,
|
||||
future=True,
|
||||
json_serializer=lambda obj: json.dumps(obj, ensure_ascii=False),
|
||||
)
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def _enable_sqlite_pragmas(dbapi_conn, _record): # noqa: ANN001
|
||||
cursor = dbapi_conn.cursor()
|
||||
try:
|
||||
cursor.execute("PRAGMA journal_mode=WAL;")
|
||||
cursor.execute("PRAGMA synchronous=NORMAL;")
|
||||
cursor.execute("PRAGMA foreign_keys=ON;")
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
session_factory = async_sessionmaker(
|
||||
bind=engine,
|
||||
class_=AsyncSession,
|
||||
expire_on_commit=False,
|
||||
autoflush=False,
|
||||
)
|
||||
|
||||
saver_cm = AsyncSqliteSaver.from_conn_string(db_url.database)
|
||||
checkpointer = await saver_cm.__aenter__()
|
||||
|
||||
async def setup() -> None:
|
||||
# 1. LangGraph checkpoint tables
|
||||
await checkpointer.setup()
|
||||
|
||||
# 2. ORM business tables
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(MappedBase.metadata.create_all)
|
||||
|
||||
async def _close_saver() -> None:
|
||||
await saver_cm.__aexit__(None, None, None)
|
||||
|
||||
async def aclose() -> None:
|
||||
await close_in_order(
|
||||
engine.dispose,
|
||||
_close_saver,
|
||||
)
|
||||
|
||||
return AppPersistence(
|
||||
checkpointer=checkpointer,
|
||||
engine=engine,
|
||||
session_factory=session_factory,
|
||||
setup=setup,
|
||||
aclose=aclose,
|
||||
)
|
||||
@@ -0,0 +1,123 @@
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import URL
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from store.common import DataBaseType
|
||||
from store.config.app_config import get_app_config
|
||||
from store.config.storage_config import StorageConfig
|
||||
from store.persistence.types import AppPersistence
|
||||
|
||||
|
||||
def storage_config_from_database_config(database_config: Any) -> StorageConfig:
|
||||
"""Convert the existing public DatabaseConfig shape to StorageConfig.
|
||||
|
||||
Storage only owns durable database-backed persistence. The app bridge
|
||||
should handle memory mode before calling into this package.
|
||||
"""
|
||||
backend = getattr(database_config, "backend", None)
|
||||
if backend == "sqlite":
|
||||
return StorageConfig(
|
||||
driver="sqlite",
|
||||
sqlite_dir=getattr(database_config, "sqlite_dir", ".deer-flow/data"),
|
||||
echo_sql=getattr(database_config, "echo_sql", False),
|
||||
pool_size=getattr(database_config, "pool_size", 5),
|
||||
)
|
||||
|
||||
if backend == "postgres":
|
||||
postgres_url = getattr(database_config, "postgres_url", "")
|
||||
if not postgres_url:
|
||||
raise ValueError("database.postgres_url is required when database.backend is 'postgres'")
|
||||
parsed = make_url(postgres_url)
|
||||
return StorageConfig(
|
||||
driver="postgres",
|
||||
database_url=postgres_url,
|
||||
username=parsed.username or "",
|
||||
password=parsed.password or "",
|
||||
host=parsed.host or "localhost",
|
||||
port=parsed.port or 5432,
|
||||
db_name=parsed.database or "deerflow",
|
||||
echo_sql=getattr(database_config, "echo_sql", False),
|
||||
pool_size=getattr(database_config, "pool_size", 5),
|
||||
)
|
||||
|
||||
raise ValueError(f"Unsupported database backend for storage persistence: {backend!r}")
|
||||
|
||||
|
||||
def _create_database_url(storage_config: StorageConfig) -> URL:
|
||||
"""Build an async SQLAlchemy URL from StorageConfig (sqlite/mysql/postgres)."""
|
||||
|
||||
if storage_config.driver == DataBaseType.sqlite:
|
||||
driver = "sqlite+aiosqlite"
|
||||
elif storage_config.driver == DataBaseType.mysql:
|
||||
driver = "mysql+aiomysql"
|
||||
elif storage_config.driver in (DataBaseType.postgresql, "postgres"):
|
||||
driver = "postgresql+asyncpg"
|
||||
else:
|
||||
raise ValueError(f"Unsupported database driver: {storage_config.driver}")
|
||||
|
||||
if storage_config.driver == DataBaseType.sqlite:
|
||||
import os
|
||||
|
||||
db_path = storage_config.sqlite_storage_path
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
url = URL.create(
|
||||
drivername=driver,
|
||||
database=db_path,
|
||||
)
|
||||
elif storage_config.database_url:
|
||||
url = make_url(storage_config.database_url)
|
||||
if storage_config.driver in (DataBaseType.postgresql, "postgres") and url.drivername == "postgresql":
|
||||
url = url.set(drivername="postgresql+asyncpg")
|
||||
elif storage_config.driver == DataBaseType.mysql and url.drivername == "mysql":
|
||||
url = url.set(drivername="mysql+aiomysql")
|
||||
else:
|
||||
url = URL.create(
|
||||
drivername=driver,
|
||||
username=storage_config.username,
|
||||
password=storage_config.password,
|
||||
host=storage_config.host,
|
||||
port=storage_config.port,
|
||||
database=storage_config.db_name or "deerflow",
|
||||
)
|
||||
|
||||
return url
|
||||
|
||||
|
||||
async def create_persistence_from_storage_config(storage_config: StorageConfig) -> AppPersistence:
|
||||
from .drivers.mysql import build_mysql_persistence
|
||||
from .drivers.postgres import build_postgres_persistence
|
||||
from .drivers.sqlite import build_sqlite_persistence
|
||||
|
||||
driver = storage_config.driver
|
||||
db_url = _create_database_url(storage_config)
|
||||
|
||||
if driver in ("postgres", "postgresql"):
|
||||
return await build_postgres_persistence(
|
||||
db_url,
|
||||
echo=storage_config.echo_sql,
|
||||
pool_size=storage_config.pool_size,
|
||||
)
|
||||
|
||||
if driver == "mysql":
|
||||
return await build_mysql_persistence(
|
||||
db_url,
|
||||
echo=storage_config.echo_sql,
|
||||
pool_size=storage_config.pool_size,
|
||||
)
|
||||
|
||||
if driver == "sqlite":
|
||||
return await build_sqlite_persistence(db_url, echo=storage_config.echo_sql)
|
||||
|
||||
raise ValueError(f"Unsupported database driver: {driver}")
|
||||
|
||||
|
||||
async def create_persistence_from_database_config(database_config: Any) -> AppPersistence:
|
||||
storage_config = storage_config_from_database_config(database_config)
|
||||
return await create_persistence_from_storage_config(storage_config)
|
||||
|
||||
|
||||
async def create_persistence() -> AppPersistence:
|
||||
app_config = get_app_config()
|
||||
return await create_persistence_from_storage_config(app_config.storage)
|
||||
@@ -0,0 +1,189 @@
|
||||
"""Dialect-aware JSON value matching for storage SQLAlchemy repositories."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import BigInteger, Float, String, bindparam
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql.compiler import SQLCompiler
|
||||
from sqlalchemy.sql.expression import ColumnElement
|
||||
from sqlalchemy.sql.visitors import InternalTraversal
|
||||
from sqlalchemy.types import Boolean, TypeEngine
|
||||
|
||||
_KEY_CHARSET_RE = re.compile(r"^[A-Za-z0-9_\-]+$")
|
||||
ALLOWED_FILTER_VALUE_TYPES: tuple[type, ...] = (type(None), bool, int, float, str)
|
||||
|
||||
_INT64_MIN = -(2**63)
|
||||
_INT64_MAX = 2**63 - 1
|
||||
|
||||
|
||||
def validate_metadata_filter_key(key: object) -> bool:
|
||||
"""Return True when *key* is safe for JSON metadata filter SQL paths."""
|
||||
return isinstance(key, str) and bool(_KEY_CHARSET_RE.match(key))
|
||||
|
||||
|
||||
def validate_metadata_filter_value(value: object) -> bool:
|
||||
"""Return True when *value* can be compiled into a portable JSON predicate."""
|
||||
if not isinstance(value, ALLOWED_FILTER_VALUE_TYPES):
|
||||
return False
|
||||
if isinstance(value, int) and not isinstance(value, bool):
|
||||
return _INT64_MIN <= value <= _INT64_MAX
|
||||
return True
|
||||
|
||||
|
||||
class JsonMatch(ColumnElement[bool]):
|
||||
"""Dialect-portable ``column[key] == value`` for JSON columns."""
|
||||
|
||||
inherit_cache = True
|
||||
type = Boolean()
|
||||
_is_implicitly_boolean = True
|
||||
|
||||
_traverse_internals = [
|
||||
("column", InternalTraversal.dp_clauseelement),
|
||||
("key", InternalTraversal.dp_string),
|
||||
("value", InternalTraversal.dp_plain_obj),
|
||||
("value_type", InternalTraversal.dp_string),
|
||||
]
|
||||
|
||||
def __init__(self, column: ColumnElement[Any], key: str, value: object) -> None:
|
||||
if not validate_metadata_filter_key(key):
|
||||
raise ValueError(f"JsonMatch key must match {_KEY_CHARSET_RE.pattern!r}; got: {key!r}")
|
||||
if not validate_metadata_filter_value(value):
|
||||
if isinstance(value, int) and not isinstance(value, bool):
|
||||
raise TypeError(f"JsonMatch int value out of signed 64-bit range [-2**63, 2**63-1]: {value!r}")
|
||||
raise TypeError(f"JsonMatch value must be None, bool, int, float, or str; got: {type(value).__name__!r}")
|
||||
self.column = column
|
||||
self.key = key
|
||||
self.value = value
|
||||
self.value_type = type(value).__qualname__
|
||||
super().__init__()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _Dialect:
|
||||
null_type: str
|
||||
num_types: tuple[str, ...]
|
||||
num_cast: str
|
||||
int_types: tuple[str, ...]
|
||||
int_cast: str
|
||||
int_guard: str | None
|
||||
string_type: str
|
||||
bool_type: str | None
|
||||
true_value: str
|
||||
false_value: str
|
||||
|
||||
|
||||
_SQLITE = _Dialect(
|
||||
null_type="null",
|
||||
num_types=("integer", "real"),
|
||||
num_cast="REAL",
|
||||
int_types=("integer",),
|
||||
int_cast="INTEGER",
|
||||
int_guard=None,
|
||||
string_type="text",
|
||||
bool_type=None,
|
||||
true_value="true",
|
||||
false_value="false",
|
||||
)
|
||||
|
||||
_POSTGRES = _Dialect(
|
||||
null_type="null",
|
||||
num_types=("number",),
|
||||
num_cast="DOUBLE PRECISION",
|
||||
int_types=("number",),
|
||||
int_cast="BIGINT",
|
||||
int_guard="'^-?[0-9]+$'",
|
||||
string_type="string",
|
||||
bool_type="boolean",
|
||||
true_value="true",
|
||||
false_value="false",
|
||||
)
|
||||
|
||||
_MYSQL = _Dialect(
|
||||
null_type="NULL",
|
||||
num_types=("INTEGER", "DOUBLE", "DECIMAL"),
|
||||
num_cast="DOUBLE",
|
||||
int_types=("INTEGER",),
|
||||
int_cast="SIGNED",
|
||||
int_guard=None,
|
||||
string_type="STRING",
|
||||
bool_type="BOOLEAN",
|
||||
true_value="true",
|
||||
false_value="false",
|
||||
)
|
||||
|
||||
|
||||
def _bind(compiler: SQLCompiler, value: object, sa_type: TypeEngine[Any], **kw: Any) -> str:
|
||||
param = bindparam(None, value, type_=sa_type)
|
||||
return compiler.process(param, **kw)
|
||||
|
||||
|
||||
def _type_check(typeof: str, types: tuple[str, ...]) -> str:
|
||||
if len(types) == 1:
|
||||
return f"{typeof} = '{types[0]}'"
|
||||
quoted = ", ".join(f"'{type_name}'" for type_name in types)
|
||||
return f"{typeof} IN ({quoted})"
|
||||
|
||||
|
||||
def _build_clause(compiler: SQLCompiler, typeof: str, extract: str, value: object, dialect: _Dialect, **kw: Any) -> str:
|
||||
if value is None:
|
||||
return f"{typeof} = '{dialect.null_type}'"
|
||||
if isinstance(value, bool):
|
||||
bool_str = dialect.true_value if value else dialect.false_value
|
||||
if dialect.bool_type is None:
|
||||
return f"{typeof} = '{bool_str}'"
|
||||
return f"({typeof} = '{dialect.bool_type}' AND {extract} = '{bool_str}')"
|
||||
if isinstance(value, int):
|
||||
bp = _bind(compiler, value, BigInteger(), **kw)
|
||||
if dialect.int_guard:
|
||||
return f"(CASE WHEN {_type_check(typeof, dialect.int_types)} AND {extract} ~ {dialect.int_guard} THEN CAST({extract} AS {dialect.int_cast}) END = {bp})"
|
||||
return f"({_type_check(typeof, dialect.int_types)} AND CAST({extract} AS {dialect.int_cast}) = {bp})"
|
||||
if isinstance(value, float):
|
||||
bp = _bind(compiler, value, Float(), **kw)
|
||||
return f"({_type_check(typeof, dialect.num_types)} AND CAST({extract} AS {dialect.num_cast}) = {bp})"
|
||||
bp = _bind(compiler, str(value), String(), **kw)
|
||||
return f"({typeof} = '{dialect.string_type}' AND {extract} = {bp})"
|
||||
|
||||
|
||||
@compiles(JsonMatch, "sqlite")
|
||||
def _compile_sqlite(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
if not validate_metadata_filter_key(element.key):
|
||||
raise ValueError(f"Key escaped validation: {element.key!r}")
|
||||
col = compiler.process(element.column, **kw)
|
||||
path = f'$."{element.key}"'
|
||||
typeof = f"json_type({col}, '{path}')"
|
||||
extract = f"json_extract({col}, '{path}')"
|
||||
return _build_clause(compiler, typeof, extract, element.value, _SQLITE, **kw)
|
||||
|
||||
|
||||
@compiles(JsonMatch, "postgresql")
|
||||
def _compile_postgres(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
if not validate_metadata_filter_key(element.key):
|
||||
raise ValueError(f"Key escaped validation: {element.key!r}")
|
||||
col = compiler.process(element.column, **kw)
|
||||
typeof = f"json_typeof({col} -> '{element.key}')"
|
||||
extract = f"({col} ->> '{element.key}')"
|
||||
return _build_clause(compiler, typeof, extract, element.value, _POSTGRES, **kw)
|
||||
|
||||
|
||||
@compiles(JsonMatch, "mysql")
|
||||
def _compile_mysql(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
if not validate_metadata_filter_key(element.key):
|
||||
raise ValueError(f"Key escaped validation: {element.key!r}")
|
||||
col = compiler.process(element.column, **kw)
|
||||
path = f'$."{element.key}"'
|
||||
typeof = f"JSON_TYPE(JSON_EXTRACT({col}, '{path}'))"
|
||||
extract = f"JSON_UNQUOTE(JSON_EXTRACT({col}, '{path}'))"
|
||||
return _build_clause(compiler, typeof, extract, element.value, _MYSQL, **kw)
|
||||
|
||||
|
||||
@compiles(JsonMatch)
|
||||
def _compile_default(element: JsonMatch, compiler: SQLCompiler, **kw: Any) -> str:
|
||||
raise NotImplementedError(f"JsonMatch supports sqlite, postgresql, and mysql; got dialect: {compiler.dialect.name}")
|
||||
|
||||
|
||||
def json_match(column: ColumnElement[Any], key: str, value: object) -> JsonMatch:
|
||||
return JsonMatch(column, key, value)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .close import close_in_order
|
||||
|
||||
__all__ = ["close_in_order"]
|
||||
@@ -0,0 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
AsyncCloser = Callable[[], Awaitable[None]]
|
||||
|
||||
|
||||
async def close_in_order(*closers: AsyncCloser) -> None:
|
||||
"""
|
||||
Run async closers in order and raise the first error, if any.
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Used to keep driver-specific close logic readable.
|
||||
- We intentionally do not stop at first failure, so later resources
|
||||
still get a chance to close.
|
||||
"""
|
||||
first_error: Exception | None = None
|
||||
|
||||
for closer in closers:
|
||||
try:
|
||||
await closer()
|
||||
except Exception as exc:
|
||||
if first_error is None:
|
||||
first_error = exc
|
||||
|
||||
if first_error is not None:
|
||||
raise first_error
|
||||
@@ -0,0 +1,23 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langgraph.types import Checkpointer
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker
|
||||
|
||||
AsyncSetup = Callable[[], Awaitable[None]]
|
||||
AsyncClose = Callable[[], Awaitable[None]]
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class AppPersistence:
|
||||
"""
|
||||
Unified runtime persistence bundle.
|
||||
"""
|
||||
|
||||
checkpointer: Checkpointer
|
||||
engine: AsyncEngine
|
||||
session_factory: async_sessionmaker[AsyncSession]
|
||||
setup: AsyncSetup
|
||||
aclose: AsyncClose
|
||||
@@ -0,0 +1,53 @@
|
||||
from store.repositories.contracts import (
|
||||
Feedback,
|
||||
FeedbackAggregate,
|
||||
FeedbackCreate,
|
||||
FeedbackRepositoryProtocol,
|
||||
InvalidMetadataFilterError,
|
||||
Run,
|
||||
RunCreate,
|
||||
RunEvent,
|
||||
RunEventCreate,
|
||||
RunEventRepositoryProtocol,
|
||||
RunRepositoryProtocol,
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
User,
|
||||
UserCreate,
|
||||
UserNotFoundError,
|
||||
UserRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.factory import (
|
||||
build_feedback_repository,
|
||||
build_run_event_repository,
|
||||
build_run_repository,
|
||||
build_thread_meta_repository,
|
||||
build_user_repository,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Feedback",
|
||||
"FeedbackAggregate",
|
||||
"FeedbackCreate",
|
||||
"FeedbackRepositoryProtocol",
|
||||
"InvalidMetadataFilterError",
|
||||
"Run",
|
||||
"RunCreate",
|
||||
"RunEvent",
|
||||
"RunEventCreate",
|
||||
"RunEventRepositoryProtocol",
|
||||
"RunRepositoryProtocol",
|
||||
"ThreadMeta",
|
||||
"ThreadMetaCreate",
|
||||
"ThreadMetaRepositoryProtocol",
|
||||
"User",
|
||||
"UserCreate",
|
||||
"UserNotFoundError",
|
||||
"UserRepositoryProtocol",
|
||||
"build_run_repository",
|
||||
"build_run_event_repository",
|
||||
"build_thread_meta_repository",
|
||||
"build_feedback_repository",
|
||||
"build_user_repository",
|
||||
]
|
||||
@@ -0,0 +1,49 @@
|
||||
from store.repositories.contracts.feedback import (
|
||||
Feedback,
|
||||
FeedbackAggregate,
|
||||
FeedbackCreate,
|
||||
FeedbackRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.contracts.run import (
|
||||
Run,
|
||||
RunCreate,
|
||||
RunRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.contracts.run_event import (
|
||||
RunEvent,
|
||||
RunEventCreate,
|
||||
RunEventRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.contracts.thread_meta import (
|
||||
InvalidMetadataFilterError,
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.contracts.user import (
|
||||
User,
|
||||
UserCreate,
|
||||
UserNotFoundError,
|
||||
UserRepositoryProtocol,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Feedback",
|
||||
"FeedbackAggregate",
|
||||
"FeedbackCreate",
|
||||
"FeedbackRepositoryProtocol",
|
||||
"Run",
|
||||
"RunCreate",
|
||||
"RunEvent",
|
||||
"RunEventCreate",
|
||||
"RunEventRepositoryProtocol",
|
||||
"RunRepositoryProtocol",
|
||||
"InvalidMetadataFilterError",
|
||||
"ThreadMeta",
|
||||
"ThreadMetaCreate",
|
||||
"ThreadMetaRepositoryProtocol",
|
||||
"User",
|
||||
"UserCreate",
|
||||
"UserNotFoundError",
|
||||
"UserRepositoryProtocol",
|
||||
]
|
||||
@@ -0,0 +1,77 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Protocol, TypedDict
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class FeedbackCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
feedback_id: str
|
||||
run_id: str
|
||||
thread_id: str
|
||||
rating: int
|
||||
user_id: str | None = None
|
||||
message_id: str | None = None
|
||||
comment: str | None = None
|
||||
|
||||
|
||||
class Feedback(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
feedback_id: str
|
||||
run_id: str
|
||||
thread_id: str
|
||||
rating: int
|
||||
user_id: str | None
|
||||
message_id: str | None
|
||||
comment: str | None
|
||||
created_time: datetime
|
||||
|
||||
|
||||
class FeedbackAggregate(TypedDict):
|
||||
run_id: str
|
||||
total: int
|
||||
positive: int
|
||||
negative: int
|
||||
|
||||
|
||||
class FeedbackRepositoryProtocol(Protocol):
|
||||
async def create_feedback(self, data: FeedbackCreate) -> Feedback:
|
||||
pass
|
||||
|
||||
async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
|
||||
pass
|
||||
|
||||
async def get_feedback(self, feedback_id: str) -> Feedback | None:
|
||||
pass
|
||||
|
||||
async def list_feedback_by_run(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
thread_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[Feedback]:
|
||||
pass
|
||||
|
||||
async def list_feedback_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[Feedback]:
|
||||
pass
|
||||
|
||||
async def delete_feedback(self, feedback_id: str) -> bool:
|
||||
pass
|
||||
|
||||
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
|
||||
pass
|
||||
|
||||
async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
|
||||
pass
|
||||
@@ -0,0 +1,100 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class RunCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
run_id: str
|
||||
thread_id: str
|
||||
assistant_id: str | None = None
|
||||
user_id: str | None = None
|
||||
status: str = "pending"
|
||||
model_name: str | None = None
|
||||
multitask_strategy: str = "reject"
|
||||
error: str | None = None
|
||||
follow_up_to_run_id: str | None = None
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
kwargs: dict[str, Any] = Field(default_factory=dict)
|
||||
created_time: datetime | None = None
|
||||
|
||||
|
||||
class Run(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
run_id: str
|
||||
thread_id: str
|
||||
assistant_id: str | None
|
||||
user_id: str | None
|
||||
status: str
|
||||
model_name: str | None
|
||||
multitask_strategy: str
|
||||
error: str | None
|
||||
follow_up_to_run_id: str | None
|
||||
metadata: dict[str, Any]
|
||||
kwargs: dict[str, Any]
|
||||
total_input_tokens: int
|
||||
total_output_tokens: int
|
||||
total_tokens: int
|
||||
llm_call_count: int
|
||||
lead_agent_tokens: int
|
||||
subagent_tokens: int
|
||||
middleware_tokens: int
|
||||
message_count: int
|
||||
first_human_message: str | None
|
||||
last_ai_message: str | None
|
||||
created_time: datetime
|
||||
updated_time: datetime | None
|
||||
|
||||
|
||||
class RunRepositoryProtocol(Protocol):
|
||||
async def create_run(self, data: RunCreate) -> Run:
|
||||
pass
|
||||
|
||||
async def get_run(self, run_id: str) -> Run | None:
|
||||
pass
|
||||
|
||||
async def list_runs_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[Run]:
|
||||
pass
|
||||
|
||||
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
|
||||
pass
|
||||
|
||||
async def delete_run(self, run_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
|
||||
pass
|
||||
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
status: str,
|
||||
total_input_tokens: int = 0,
|
||||
total_output_tokens: int = 0,
|
||||
total_tokens: int = 0,
|
||||
llm_call_count: int = 0,
|
||||
lead_agent_tokens: int = 0,
|
||||
subagent_tokens: int = 0,
|
||||
middleware_tokens: int = 0,
|
||||
message_count: int = 0,
|
||||
first_human_message: str | None = None,
|
||||
last_ai_message: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
pass
|
||||
@@ -0,0 +1,83 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class RunEventCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
thread_id: str
|
||||
run_id: str
|
||||
user_id: str | None = None
|
||||
event_type: str
|
||||
category: str
|
||||
content: Any = ""
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
created_at: datetime | None = None
|
||||
|
||||
|
||||
class RunEvent(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
thread_id: str
|
||||
run_id: str
|
||||
user_id: str | None
|
||||
event_type: str
|
||||
category: str
|
||||
content: Any
|
||||
metadata: dict[str, Any]
|
||||
seq: int
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class RunEventRepositoryProtocol(Protocol):
|
||||
# Sequence values are time-ordered integer cursors. The application layer
|
||||
# owns the single-writer invariant for a thread while a run is active.
|
||||
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
|
||||
pass
|
||||
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]:
|
||||
pass
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
event_types: list[str] | None = None,
|
||||
limit: int = 500,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]:
|
||||
pass
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]:
|
||||
pass
|
||||
|
||||
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
|
||||
pass
|
||||
|
||||
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
|
||||
pass
|
||||
|
||||
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
|
||||
pass
|
||||
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class InvalidMetadataFilterError(ValueError):
|
||||
"""Raised when all client-supplied metadata filters are rejected."""
|
||||
|
||||
|
||||
class ThreadMetaCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
thread_id: str
|
||||
assistant_id: str | None = None
|
||||
user_id: str | None = None
|
||||
display_name: str | None = None
|
||||
status: str = "idle"
|
||||
metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ThreadMeta(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
thread_id: str
|
||||
assistant_id: str | None
|
||||
user_id: str | None
|
||||
display_name: str | None
|
||||
status: str
|
||||
metadata: dict[str, Any]
|
||||
created_time: datetime
|
||||
updated_time: datetime | None
|
||||
|
||||
|
||||
class ThreadMetaRepositoryProtocol(Protocol):
|
||||
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
|
||||
pass
|
||||
|
||||
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
|
||||
pass
|
||||
|
||||
async def update_thread_meta(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
display_name: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
async def delete_thread(self, thread_id: str) -> None:
|
||||
pass
|
||||
|
||||
async def search_threads(
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None = None,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ThreadMeta]:
|
||||
pass
|
||||
@@ -0,0 +1,64 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class UserNotFoundError(LookupError):
|
||||
"""Raised when an update targets a user row that no longer exists."""
|
||||
|
||||
|
||||
class UserCreate(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
id: str
|
||||
email: str
|
||||
password_hash: str | None = None
|
||||
system_role: Literal["admin", "user"] = "user"
|
||||
created_at: datetime | None = None
|
||||
oauth_provider: str | None = None
|
||||
oauth_id: str | None = None
|
||||
needs_setup: bool = False
|
||||
token_version: int = 0
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
id: str
|
||||
email: str
|
||||
password_hash: str | None
|
||||
system_role: Literal["admin", "user"]
|
||||
created_at: datetime
|
||||
oauth_provider: str | None
|
||||
oauth_id: str | None
|
||||
needs_setup: bool
|
||||
token_version: int
|
||||
|
||||
|
||||
class UserRepositoryProtocol(Protocol):
|
||||
async def create_user(self, data: UserCreate) -> User:
|
||||
pass
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
pass
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
pass
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
pass
|
||||
|
||||
async def get_first_admin(self) -> User | None:
|
||||
pass
|
||||
|
||||
async def update_user(self, data: User) -> User:
|
||||
pass
|
||||
|
||||
async def count_users(self) -> int:
|
||||
pass
|
||||
|
||||
async def count_admin_users(self) -> int:
|
||||
pass
|
||||
@@ -0,0 +1,13 @@
|
||||
from store.repositories.db.feedback import DbFeedbackRepository
|
||||
from store.repositories.db.run import DbRunRepository
|
||||
from store.repositories.db.run_event import DbRunEventRepository
|
||||
from store.repositories.db.thread_meta import DbThreadMetaRepository
|
||||
from store.repositories.db.user import DbUserRepository
|
||||
|
||||
__all__ = [
|
||||
"DbFeedbackRepository",
|
||||
"DbRunRepository",
|
||||
"DbRunEventRepository",
|
||||
"DbThreadMetaRepository",
|
||||
"DbUserRepository",
|
||||
]
|
||||
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from sqlalchemy import case, delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.feedback import Feedback, FeedbackAggregate, FeedbackCreate, FeedbackRepositoryProtocol
|
||||
from store.repositories.models.feedback import Feedback as FeedbackModel
|
||||
|
||||
|
||||
def _to_feedback(m: FeedbackModel) -> Feedback:
|
||||
return Feedback(
|
||||
feedback_id=m.feedback_id,
|
||||
run_id=m.run_id,
|
||||
thread_id=m.thread_id,
|
||||
rating=m.rating,
|
||||
user_id=m.user_id,
|
||||
message_id=m.message_id,
|
||||
comment=m.comment,
|
||||
created_time=m.created_time,
|
||||
)
|
||||
|
||||
|
||||
class DbFeedbackRepository(FeedbackRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def create_feedback(self, data: FeedbackCreate) -> Feedback:
|
||||
if data.rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {data.rating}")
|
||||
model = FeedbackModel(
|
||||
feedback_id=data.feedback_id,
|
||||
run_id=data.run_id,
|
||||
thread_id=data.thread_id,
|
||||
rating=data.rating,
|
||||
user_id=data.user_id,
|
||||
message_id=data.message_id,
|
||||
comment=data.comment,
|
||||
)
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
return _to_feedback(model)
|
||||
|
||||
async def upsert_feedback(self, data: FeedbackCreate) -> Feedback:
|
||||
if data.rating not in (1, -1):
|
||||
raise ValueError(f"rating must be +1 or -1, got {data.rating}")
|
||||
|
||||
result = await self._session.execute(
|
||||
select(FeedbackModel).where(
|
||||
FeedbackModel.thread_id == data.thread_id,
|
||||
FeedbackModel.run_id == data.run_id,
|
||||
FeedbackModel.user_id == data.user_id,
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
if model is None:
|
||||
return await self.create_feedback(data)
|
||||
|
||||
model.rating = data.rating
|
||||
model.message_id = data.message_id
|
||||
model.comment = data.comment
|
||||
model.created_time = datetime.now(UTC)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
return _to_feedback(model)
|
||||
|
||||
async def get_feedback(self, feedback_id: str) -> Feedback | None:
|
||||
result = await self._session.execute(select(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_feedback(model) if model else None
|
||||
|
||||
async def list_feedback_by_run(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
thread_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[Feedback]:
|
||||
stmt = select(FeedbackModel).where(FeedbackModel.run_id == run_id)
|
||||
if thread_id is not None:
|
||||
stmt = stmt.where(FeedbackModel.thread_id == thread_id)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(FeedbackModel.user_id == user_id)
|
||||
stmt = stmt.order_by(FeedbackModel.created_time.desc())
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_feedback(m) for m in result.scalars().all()]
|
||||
|
||||
async def list_feedback_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> list[Feedback]:
|
||||
stmt = select(FeedbackModel).where(FeedbackModel.thread_id == thread_id)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(FeedbackModel.user_id == user_id)
|
||||
stmt = stmt.order_by(FeedbackModel.created_time.desc())
|
||||
if limit is not None:
|
||||
stmt = stmt.limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_feedback(m) for m in result.scalars().all()]
|
||||
|
||||
async def delete_feedback(self, feedback_id: str) -> bool:
|
||||
existing = await self.get_feedback(feedback_id)
|
||||
if existing is None:
|
||||
return False
|
||||
await self._session.execute(delete(FeedbackModel).where(FeedbackModel.feedback_id == feedback_id))
|
||||
return True
|
||||
|
||||
async def delete_feedback_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> bool:
|
||||
stmt = select(FeedbackModel).where(
|
||||
FeedbackModel.thread_id == thread_id,
|
||||
FeedbackModel.run_id == run_id,
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(FeedbackModel.user_id == user_id)
|
||||
result = await self._session.execute(stmt)
|
||||
model = result.scalar_one_or_none()
|
||||
if model is None:
|
||||
return False
|
||||
await self._session.delete(model)
|
||||
return True
|
||||
|
||||
async def aggregate_feedback_by_run(self, thread_id: str, run_id: str) -> FeedbackAggregate:
|
||||
stmt = select(
|
||||
func.count().label("total"),
|
||||
func.coalesce(func.sum(case((FeedbackModel.rating == 1, 1), else_=0)), 0).label("positive"),
|
||||
func.coalesce(func.sum(case((FeedbackModel.rating == -1, 1), else_=0)), 0).label("negative"),
|
||||
).where(FeedbackModel.thread_id == thread_id, FeedbackModel.run_id == run_id)
|
||||
row = (await self._session.execute(stmt)).one()
|
||||
return {
|
||||
"run_id": run_id,
|
||||
"total": int(row.total),
|
||||
"positive": int(row.positive),
|
||||
"negative": int(row.negative),
|
||||
}
|
||||
@@ -0,0 +1,185 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.run import Run, RunCreate, RunRepositoryProtocol
|
||||
from store.repositories.models.run import Run as RunModel
|
||||
|
||||
|
||||
def _to_run(m: RunModel) -> Run:
|
||||
return Run(
|
||||
run_id=m.run_id,
|
||||
thread_id=m.thread_id,
|
||||
assistant_id=m.assistant_id,
|
||||
user_id=m.user_id,
|
||||
status=m.status,
|
||||
model_name=m.model_name,
|
||||
multitask_strategy=m.multitask_strategy,
|
||||
error=m.error,
|
||||
follow_up_to_run_id=m.follow_up_to_run_id,
|
||||
metadata=dict(m.meta or {}),
|
||||
kwargs=dict(m.kwargs or {}),
|
||||
total_input_tokens=m.total_input_tokens,
|
||||
total_output_tokens=m.total_output_tokens,
|
||||
total_tokens=m.total_tokens,
|
||||
llm_call_count=m.llm_call_count,
|
||||
lead_agent_tokens=m.lead_agent_tokens,
|
||||
subagent_tokens=m.subagent_tokens,
|
||||
middleware_tokens=m.middleware_tokens,
|
||||
message_count=m.message_count,
|
||||
first_human_message=m.first_human_message,
|
||||
last_ai_message=m.last_ai_message,
|
||||
created_time=m.created_time,
|
||||
updated_time=m.updated_time,
|
||||
)
|
||||
|
||||
|
||||
class DbRunRepository(RunRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def create_run(self, data: RunCreate) -> Run:
|
||||
model = RunModel(
|
||||
run_id=data.run_id,
|
||||
thread_id=data.thread_id,
|
||||
assistant_id=data.assistant_id,
|
||||
user_id=data.user_id,
|
||||
status=data.status,
|
||||
model_name=data.model_name,
|
||||
multitask_strategy=data.multitask_strategy,
|
||||
error=data.error,
|
||||
follow_up_to_run_id=data.follow_up_to_run_id,
|
||||
meta=dict(data.metadata),
|
||||
kwargs=dict(data.kwargs),
|
||||
)
|
||||
if data.created_time is not None:
|
||||
model.created_time = data.created_time
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
return _to_run(model)
|
||||
|
||||
async def get_run(self, run_id: str) -> Run | None:
|
||||
result = await self._session.execute(select(RunModel).where(RunModel.run_id == run_id))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_run(model) if model else None
|
||||
|
||||
async def list_runs_by_thread(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
user_id: str | None = None,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> list[Run]:
|
||||
stmt = select(RunModel).where(RunModel.thread_id == thread_id)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(RunModel.user_id == user_id)
|
||||
stmt = stmt.order_by(RunModel.created_time.desc()).limit(limit).offset(offset)
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_run(m) for m in result.scalars().all()]
|
||||
|
||||
async def update_run_status(self, run_id: str, status: str, *, error: str | None = None) -> None:
|
||||
values: dict = {"status": status}
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
|
||||
|
||||
async def delete_run(self, run_id: str) -> None:
|
||||
await self._session.execute(delete(RunModel).where(RunModel.run_id == run_id))
|
||||
|
||||
async def list_pending(self, *, before: datetime | str | None = None) -> list[Run]:
|
||||
if before is None:
|
||||
before_dt = datetime.now().astimezone()
|
||||
elif isinstance(before, datetime):
|
||||
before_dt = before
|
||||
else:
|
||||
before_dt = datetime.fromisoformat(before)
|
||||
|
||||
result = await self._session.execute(select(RunModel).where(RunModel.status == "pending", RunModel.created_time <= before_dt).order_by(RunModel.created_time.asc()))
|
||||
return [_to_run(m) for m in result.scalars().all()]
|
||||
|
||||
async def update_run_completion(
|
||||
self,
|
||||
run_id: str,
|
||||
*,
|
||||
status: str,
|
||||
total_input_tokens: int = 0,
|
||||
total_output_tokens: int = 0,
|
||||
total_tokens: int = 0,
|
||||
llm_call_count: int = 0,
|
||||
lead_agent_tokens: int = 0,
|
||||
subagent_tokens: int = 0,
|
||||
middleware_tokens: int = 0,
|
||||
message_count: int = 0,
|
||||
first_human_message: str | None = None,
|
||||
last_ai_message: str | None = None,
|
||||
error: str | None = None,
|
||||
) -> None:
|
||||
values = {
|
||||
"status": status,
|
||||
"total_input_tokens": total_input_tokens,
|
||||
"total_output_tokens": total_output_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"llm_call_count": llm_call_count,
|
||||
"lead_agent_tokens": lead_agent_tokens,
|
||||
"subagent_tokens": subagent_tokens,
|
||||
"middleware_tokens": middleware_tokens,
|
||||
"message_count": message_count,
|
||||
}
|
||||
if first_human_message is not None:
|
||||
values["first_human_message"] = first_human_message[:2000]
|
||||
if last_ai_message is not None:
|
||||
values["last_ai_message"] = last_ai_message[:2000]
|
||||
if error is not None:
|
||||
values["error"] = error
|
||||
await self._session.execute(update(RunModel).where(RunModel.run_id == run_id).values(**values))
|
||||
|
||||
async def aggregate_tokens_by_thread(self, thread_id: str) -> dict[str, Any]:
|
||||
completed = RunModel.status.in_(("success", "error"))
|
||||
model_expr = func.coalesce(RunModel.model_name, "unknown")
|
||||
stmt = (
|
||||
select(
|
||||
model_expr.label("model"),
|
||||
func.count().label("runs"),
|
||||
func.coalesce(func.sum(RunModel.total_tokens), 0).label("total_tokens"),
|
||||
func.coalesce(func.sum(RunModel.total_input_tokens), 0).label("total_input_tokens"),
|
||||
func.coalesce(func.sum(RunModel.total_output_tokens), 0).label("total_output_tokens"),
|
||||
func.coalesce(func.sum(RunModel.lead_agent_tokens), 0).label("lead_agent"),
|
||||
func.coalesce(func.sum(RunModel.subagent_tokens), 0).label("subagent"),
|
||||
func.coalesce(func.sum(RunModel.middleware_tokens), 0).label("middleware"),
|
||||
)
|
||||
.where(RunModel.thread_id == thread_id, completed)
|
||||
.group_by(model_expr)
|
||||
)
|
||||
|
||||
rows = (await self._session.execute(stmt)).all()
|
||||
total_tokens = total_input = total_output = total_runs = 0
|
||||
lead_agent = subagent = middleware = 0
|
||||
by_model: dict[str, dict] = {}
|
||||
for row in rows:
|
||||
by_model[row.model] = {"tokens": row.total_tokens, "runs": row.runs}
|
||||
total_tokens += row.total_tokens
|
||||
total_input += row.total_input_tokens
|
||||
total_output += row.total_output_tokens
|
||||
total_runs += row.runs
|
||||
lead_agent += row.lead_agent
|
||||
subagent += row.subagent
|
||||
middleware += row.middleware
|
||||
|
||||
return {
|
||||
"total_tokens": total_tokens,
|
||||
"total_input_tokens": total_input,
|
||||
"total_output_tokens": total_output,
|
||||
"total_runs": total_runs,
|
||||
"by_model": by_model,
|
||||
"by_caller": {
|
||||
"lead_agent": lead_agent,
|
||||
"subagent": subagent,
|
||||
"middleware": middleware,
|
||||
},
|
||||
}
|
||||
@@ -0,0 +1,207 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.run_event import RunEvent, RunEventCreate, RunEventRepositoryProtocol
|
||||
from store.repositories.models.run_event import RunEvent as RunEventModel
|
||||
|
||||
_SEQ_COUNTER_BITS = 12
|
||||
_SEQ_PROCESS_BITS = 9
|
||||
_SEQ_PROCESS_SALT = secrets.randbits(_SEQ_PROCESS_BITS)
|
||||
_SEQ_COUNTER_LIMIT = 1 << _SEQ_COUNTER_BITS
|
||||
_SEQ_TIMESTAMP_SHIFT = _SEQ_COUNTER_BITS + _SEQ_PROCESS_BITS
|
||||
|
||||
|
||||
class _SequenceAllocator:
|
||||
def __init__(self) -> None:
|
||||
self._last_millis = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def allocate_base(self, batch_size: int) -> int:
|
||||
if batch_size >= _SEQ_COUNTER_LIMIT:
|
||||
raise ValueError(f"Run event batch is too large: {batch_size} >= {_SEQ_COUNTER_LIMIT}")
|
||||
|
||||
now_ms = time.time_ns() // 1_000_000
|
||||
with self._lock:
|
||||
seq_ms = max(now_ms, self._last_millis + 1)
|
||||
self._last_millis = seq_ms
|
||||
return (seq_ms << _SEQ_TIMESTAMP_SHIFT) | (_SEQ_PROCESS_SALT << _SEQ_COUNTER_BITS)
|
||||
|
||||
|
||||
_sequence_allocator = _SequenceAllocator()
|
||||
|
||||
|
||||
def _serialize_content(content: Any, metadata: dict[str, Any]) -> tuple[str, dict[str, Any]]:
|
||||
if not isinstance(content, str):
|
||||
next_metadata = {**metadata, "content_is_json": True}
|
||||
if isinstance(content, dict):
|
||||
next_metadata["content_is_dict"] = True
|
||||
return json.dumps(content, default=str, ensure_ascii=False), next_metadata
|
||||
return content, metadata
|
||||
|
||||
|
||||
def _deserialize_content(content: str, metadata: dict[str, Any]) -> Any:
|
||||
if not (metadata.get("content_is_json") or metadata.get("content_is_dict")):
|
||||
return content
|
||||
try:
|
||||
return json.loads(content)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
|
||||
def _to_run_event(model: RunEventModel) -> RunEvent:
|
||||
raw_metadata = dict(model.meta or {})
|
||||
metadata = {key: value for key, value in raw_metadata.items() if key != "content_is_dict"}
|
||||
return RunEvent(
|
||||
thread_id=model.thread_id,
|
||||
run_id=model.run_id,
|
||||
user_id=model.user_id,
|
||||
event_type=model.event_type,
|
||||
category=model.category,
|
||||
content=_deserialize_content(model.content, raw_metadata),
|
||||
metadata=metadata,
|
||||
seq=model.seq,
|
||||
created_at=model.created_at,
|
||||
)
|
||||
|
||||
|
||||
class DbRunEventRepository(RunEventRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def append_batch(self, events: list[RunEventCreate]) -> list[RunEvent]:
|
||||
if not events:
|
||||
return []
|
||||
|
||||
seq_base = _sequence_allocator.allocate_base(len(events))
|
||||
|
||||
rows: list[RunEventModel] = []
|
||||
|
||||
for index, event in enumerate(events, start=1):
|
||||
content, metadata = _serialize_content(event.content, dict(event.metadata))
|
||||
row = RunEventModel(
|
||||
thread_id=event.thread_id,
|
||||
run_id=event.run_id,
|
||||
user_id=event.user_id,
|
||||
seq=seq_base + index,
|
||||
event_type=event.event_type,
|
||||
category=event.category,
|
||||
content=content,
|
||||
meta=metadata,
|
||||
)
|
||||
if event.created_at is not None:
|
||||
row.created_at = event.created_at
|
||||
self._session.add(row)
|
||||
rows.append(row)
|
||||
|
||||
await self._session.flush()
|
||||
return [_to_run_event(row) for row in rows]
|
||||
|
||||
async def list_messages(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]:
|
||||
stmt = select(RunEventModel).where(
|
||||
RunEventModel.thread_id == thread_id,
|
||||
RunEventModel.category == "message",
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(RunEventModel.user_id == user_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_run_event(row) for row in result.scalars().all()]
|
||||
|
||||
stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
|
||||
|
||||
async def list_events(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
event_types: list[str] | None = None,
|
||||
limit: int = 500,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]:
|
||||
stmt = select(RunEventModel).where(
|
||||
RunEventModel.thread_id == thread_id,
|
||||
RunEventModel.run_id == run_id,
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(RunEventModel.user_id == user_id)
|
||||
if event_types is not None:
|
||||
stmt = stmt.where(RunEventModel.event_type.in_(event_types))
|
||||
stmt = stmt.order_by(RunEventModel.seq.asc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_run_event(row) for row in result.scalars().all()]
|
||||
|
||||
async def list_messages_by_run(
|
||||
self,
|
||||
thread_id: str,
|
||||
run_id: str,
|
||||
*,
|
||||
limit: int = 50,
|
||||
before_seq: int | None = None,
|
||||
after_seq: int | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> list[RunEvent]:
|
||||
stmt = select(RunEventModel).where(
|
||||
RunEventModel.thread_id == thread_id,
|
||||
RunEventModel.run_id == run_id,
|
||||
RunEventModel.category == "message",
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(RunEventModel.user_id == user_id)
|
||||
if before_seq is not None:
|
||||
stmt = stmt.where(RunEventModel.seq < before_seq).order_by(RunEventModel.seq.desc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
|
||||
if after_seq is not None:
|
||||
stmt = stmt.where(RunEventModel.seq > after_seq).order_by(RunEventModel.seq.asc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_run_event(row) for row in result.scalars().all()]
|
||||
|
||||
stmt = stmt.order_by(RunEventModel.seq.desc()).limit(limit)
|
||||
result = await self._session.execute(stmt)
|
||||
return list(reversed([_to_run_event(row) for row in result.scalars().all()]))
|
||||
|
||||
async def count_messages(self, thread_id: str, *, user_id: str | None = None) -> int:
|
||||
stmt = select(func.count()).select_from(RunEventModel).where(RunEventModel.thread_id == thread_id, RunEventModel.category == "message")
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(RunEventModel.user_id == user_id)
|
||||
count = await self._session.scalar(stmt)
|
||||
return int(count or 0)
|
||||
|
||||
async def delete_by_thread(self, thread_id: str, *, user_id: str | None = None) -> int:
|
||||
conditions = [RunEventModel.thread_id == thread_id]
|
||||
if user_id is not None:
|
||||
conditions.append(RunEventModel.user_id == user_id)
|
||||
count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
|
||||
await self._session.execute(delete(RunEventModel).where(*conditions))
|
||||
return int(count or 0)
|
||||
|
||||
async def delete_by_run(self, thread_id: str, run_id: str, *, user_id: str | None = None) -> int:
|
||||
conditions = [RunEventModel.thread_id == thread_id, RunEventModel.run_id == run_id]
|
||||
if user_id is not None:
|
||||
conditions.append(RunEventModel.user_id == user_id)
|
||||
count = await self._session.scalar(select(func.count()).select_from(RunEventModel).where(*conditions))
|
||||
await self._session.execute(delete(RunEventModel).where(*conditions))
|
||||
return int(count or 0)
|
||||
@@ -0,0 +1,113 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.persistence.json_compat import json_match
|
||||
from store.repositories.contracts.thread_meta import (
|
||||
InvalidMetadataFilterError,
|
||||
ThreadMeta,
|
||||
ThreadMetaCreate,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.models.thread_meta import ThreadMeta as ThreadMetaModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _to_thread_meta(m: ThreadMetaModel) -> ThreadMeta:
|
||||
return ThreadMeta(
|
||||
thread_id=m.thread_id,
|
||||
assistant_id=m.assistant_id,
|
||||
user_id=m.user_id,
|
||||
display_name=m.display_name,
|
||||
status=m.status,
|
||||
metadata=dict(m.meta or {}),
|
||||
created_time=m.created_time,
|
||||
updated_time=m.updated_time,
|
||||
)
|
||||
|
||||
|
||||
class DbThreadMetaRepository(ThreadMetaRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def create_thread_meta(self, data: ThreadMetaCreate) -> ThreadMeta:
|
||||
model = ThreadMetaModel(
|
||||
thread_id=data.thread_id,
|
||||
assistant_id=data.assistant_id,
|
||||
user_id=data.user_id,
|
||||
display_name=data.display_name,
|
||||
status=data.status,
|
||||
meta=dict(data.metadata),
|
||||
)
|
||||
self._session.add(model)
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
return _to_thread_meta(model)
|
||||
|
||||
async def get_thread_meta(self, thread_id: str) -> ThreadMeta | None:
|
||||
result = await self._session.execute(select(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_thread_meta(model) if model else None
|
||||
|
||||
async def update_thread_meta(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
display_name: str | None = None,
|
||||
status: str | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
values: dict = {}
|
||||
if display_name is not None:
|
||||
values["display_name"] = display_name
|
||||
if status is not None:
|
||||
values["status"] = status
|
||||
if metadata is not None:
|
||||
values["meta"] = dict(metadata)
|
||||
if not values:
|
||||
return
|
||||
await self._session.execute(update(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id).values(**values))
|
||||
|
||||
async def delete_thread(self, thread_id: str) -> None:
|
||||
await self._session.execute(delete(ThreadMetaModel).where(ThreadMetaModel.thread_id == thread_id))
|
||||
|
||||
async def search_threads(
|
||||
self,
|
||||
*,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
status: str | None = None,
|
||||
user_id: str | None = None,
|
||||
assistant_id: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
) -> list[ThreadMeta]:
|
||||
stmt = select(ThreadMetaModel)
|
||||
|
||||
if status is not None:
|
||||
stmt = stmt.where(ThreadMetaModel.status == status)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(ThreadMetaModel.user_id == user_id)
|
||||
if assistant_id is not None:
|
||||
stmt = stmt.where(ThreadMetaModel.assistant_id == assistant_id)
|
||||
if metadata:
|
||||
applied = 0
|
||||
for key, value in metadata.items():
|
||||
try:
|
||||
stmt = stmt.where(json_match(ThreadMetaModel.meta, key, value))
|
||||
applied += 1
|
||||
except (ValueError, TypeError) as exc:
|
||||
logger.warning("Skipping metadata filter key %s: %s", ascii(key), exc)
|
||||
if applied == 0:
|
||||
rejected_keys = ", ".join(sorted(str(key) for key in metadata))
|
||||
raise InvalidMetadataFilterError(f"All metadata filter keys were rejected as unsafe: {rejected_keys}")
|
||||
|
||||
stmt = stmt.order_by(ThreadMetaModel.created_time.desc(), ThreadMetaModel.thread_id.desc())
|
||||
stmt = stmt.limit(limit).offset(offset)
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
return [_to_thread_meta(m) for m in result.scalars().all()]
|
||||
@@ -0,0 +1,98 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories.contracts.user import User, UserCreate, UserNotFoundError, UserRepositoryProtocol
|
||||
from store.repositories.models.user import User as UserModel
|
||||
|
||||
|
||||
def _to_user(model: UserModel) -> User:
|
||||
return User(
|
||||
id=model.id,
|
||||
email=model.email,
|
||||
password_hash=model.password_hash,
|
||||
system_role=model.system_role, # type: ignore[arg-type]
|
||||
created_at=model.created_at,
|
||||
oauth_provider=model.oauth_provider,
|
||||
oauth_id=model.oauth_id,
|
||||
needs_setup=model.needs_setup,
|
||||
token_version=model.token_version,
|
||||
)
|
||||
|
||||
|
||||
class DbUserRepository(UserRepositoryProtocol):
|
||||
def __init__(self, session: AsyncSession) -> None:
|
||||
self._session = session
|
||||
|
||||
async def create_user(self, data: UserCreate) -> User:
|
||||
model = UserModel(
|
||||
id=data.id,
|
||||
email=data.email,
|
||||
system_role=data.system_role,
|
||||
password_hash=data.password_hash,
|
||||
oauth_provider=data.oauth_provider,
|
||||
oauth_id=data.oauth_id,
|
||||
needs_setup=data.needs_setup,
|
||||
token_version=data.token_version,
|
||||
)
|
||||
if data.created_at is not None:
|
||||
model.created_at = data.created_at
|
||||
self._session.add(model)
|
||||
try:
|
||||
await self._session.flush()
|
||||
except IntegrityError as exc:
|
||||
await self._session.rollback()
|
||||
raise ValueError(f"Email already registered: {data.email}") from exc
|
||||
await self._session.refresh(model)
|
||||
return _to_user(model)
|
||||
|
||||
async def get_user_by_id(self, user_id: str) -> User | None:
|
||||
model = await self._session.get(UserModel, user_id)
|
||||
return _to_user(model) if model is not None else None
|
||||
|
||||
async def get_user_by_email(self, email: str) -> User | None:
|
||||
result = await self._session.execute(select(UserModel).where(UserModel.email == email))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_user(model) if model is not None else None
|
||||
|
||||
async def get_user_by_oauth(self, provider: str, oauth_id: str) -> User | None:
|
||||
result = await self._session.execute(
|
||||
select(UserModel).where(
|
||||
UserModel.oauth_provider == provider,
|
||||
UserModel.oauth_id == oauth_id,
|
||||
)
|
||||
)
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_user(model) if model is not None else None
|
||||
|
||||
async def get_first_admin(self) -> User | None:
|
||||
result = await self._session.execute(select(UserModel).where(UserModel.system_role == "admin").limit(1))
|
||||
model = result.scalar_one_or_none()
|
||||
return _to_user(model) if model is not None else None
|
||||
|
||||
async def update_user(self, data: User) -> User:
|
||||
model = await self._session.get(UserModel, data.id)
|
||||
if model is None:
|
||||
raise UserNotFoundError(f"User {data.id} no longer exists")
|
||||
|
||||
model.email = data.email
|
||||
model.password_hash = data.password_hash
|
||||
model.system_role = data.system_role
|
||||
model.oauth_provider = data.oauth_provider
|
||||
model.oauth_id = data.oauth_id
|
||||
model.needs_setup = data.needs_setup
|
||||
model.token_version = data.token_version
|
||||
|
||||
await self._session.flush()
|
||||
await self._session.refresh(model)
|
||||
return _to_user(model)
|
||||
|
||||
async def count_users(self) -> int:
|
||||
count = await self._session.scalar(select(func.count()).select_from(UserModel))
|
||||
return int(count or 0)
|
||||
|
||||
async def count_admin_users(self) -> int:
|
||||
count = await self._session.scalar(select(func.count()).select_from(UserModel).where(UserModel.system_role == "admin"))
|
||||
return int(count or 0)
|
||||
@@ -0,0 +1,36 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from store.repositories import (
|
||||
FeedbackRepositoryProtocol,
|
||||
RunEventRepositoryProtocol,
|
||||
RunRepositoryProtocol,
|
||||
ThreadMetaRepositoryProtocol,
|
||||
UserRepositoryProtocol,
|
||||
)
|
||||
from store.repositories.db import (
|
||||
DbFeedbackRepository,
|
||||
DbRunEventRepository,
|
||||
DbRunRepository,
|
||||
DbThreadMetaRepository,
|
||||
DbUserRepository,
|
||||
)
|
||||
|
||||
|
||||
def build_thread_meta_repository(session: AsyncSession) -> ThreadMetaRepositoryProtocol:
|
||||
return DbThreadMetaRepository(session)
|
||||
|
||||
|
||||
def build_run_repository(session: AsyncSession) -> RunRepositoryProtocol:
|
||||
return DbRunRepository(session)
|
||||
|
||||
|
||||
def build_feedback_repository(session: AsyncSession) -> FeedbackRepositoryProtocol:
|
||||
return DbFeedbackRepository(session)
|
||||
|
||||
|
||||
def build_run_event_repository(session: AsyncSession) -> RunEventRepositoryProtocol:
|
||||
return DbRunEventRepository(session)
|
||||
|
||||
|
||||
def build_user_repository(session: AsyncSession) -> UserRepositoryProtocol:
|
||||
return DbUserRepository(session)
|
||||
@@ -0,0 +1,7 @@
|
||||
from store.repositories.models.feedback import Feedback
|
||||
from store.repositories.models.run import Run
|
||||
from store.repositories.models.run_event import RunEvent
|
||||
from store.repositories.models.thread_meta import ThreadMeta
|
||||
from store.repositories.models.user import User
|
||||
|
||||
__all__ = ["Feedback", "Run", "RunEvent", "ThreadMeta", "User"]
|
||||
@@ -0,0 +1,36 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Integer, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
|
||||
|
||||
|
||||
class Feedback(DataClassBase):
|
||||
"""Feedback table (create-only, no updated_time)."""
|
||||
|
||||
__tablename__ = "feedback"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("thread_id", "run_id", "user_id", name="uq_feedback_thread_run_user"),
|
||||
{"comment": "Feedback table."},
|
||||
)
|
||||
|
||||
feedback_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
run_id: Mapped[str] = mapped_column(String(64), index=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), index=True)
|
||||
rating: Mapped[int] = mapped_column(Integer)
|
||||
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
|
||||
message_id: Mapped[str | None] = mapped_column(String(64), default=None)
|
||||
comment: Mapped[str | None] = mapped_column(UniversalText, default=None)
|
||||
|
||||
created_time: Mapped[datetime] = mapped_column(
|
||||
"created_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
@@ -0,0 +1,63 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, Index, Integer, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, UniversalText, current_time
|
||||
|
||||
|
||||
class Run(DataClassBase):
|
||||
"""Run metadata table."""
|
||||
|
||||
__tablename__ = "runs"
|
||||
__table_args__ = (
|
||||
Index("ix_runs_thread_status", "thread_id", "status"),
|
||||
{"comment": "Run metadata table."},
|
||||
)
|
||||
|
||||
run_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
thread_id: Mapped[str] = mapped_column(String(64), index=True)
|
||||
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128), default=None)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
|
||||
status: Mapped[str] = mapped_column(String(20), default="pending", index=True)
|
||||
model_name: Mapped[str | None] = mapped_column(String(128), default=None)
|
||||
multitask_strategy: Mapped[str] = mapped_column(String(20), default="reject")
|
||||
error: Mapped[str | None] = mapped_column(UniversalText, default=None)
|
||||
follow_up_to_run_id: Mapped[str | None] = mapped_column(String(64), default=None)
|
||||
|
||||
meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
|
||||
kwargs: Mapped[dict[str, Any]] = mapped_column("kwargs_json", JSON, default_factory=dict)
|
||||
|
||||
total_input_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
total_output_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
total_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
llm_call_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
lead_agent_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
subagent_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
middleware_tokens: Mapped[int] = mapped_column(Integer, default=0)
|
||||
|
||||
message_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||
first_human_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
|
||||
last_ai_message: Mapped[str | None] = mapped_column(UniversalText, default=None)
|
||||
|
||||
created_time: Mapped[datetime] = mapped_column(
|
||||
"created_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
updated_time: Mapped[datetime | None] = mapped_column(
|
||||
"updated_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default=None,
|
||||
onupdate=current_time,
|
||||
sort_order=999,
|
||||
comment="Updated at",
|
||||
)
|
||||
@@ -0,0 +1,46 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, BigInteger, Index, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import (
|
||||
DataClassBase,
|
||||
TimeZone,
|
||||
UniversalText,
|
||||
current_time,
|
||||
id_key,
|
||||
)
|
||||
|
||||
|
||||
class RunEvent(DataClassBase):
|
||||
"""Run event table."""
|
||||
|
||||
__tablename__ = "run_events"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("thread_id", "seq", name="uq_events_thread_seq"),
|
||||
Index("ix_events_thread_cat_seq", "thread_id", "category", "seq"),
|
||||
Index("ix_events_run", "thread_id", "run_id", "seq"),
|
||||
{"comment": "Run event table."},
|
||||
)
|
||||
|
||||
id: Mapped[id_key] = mapped_column(init=False)
|
||||
|
||||
thread_id: Mapped[str] = mapped_column(String(64), index=True)
|
||||
run_id: Mapped[str] = mapped_column(String(64), index=True)
|
||||
event_type: Mapped[str] = mapped_column(String(32), index=True)
|
||||
category: Mapped[str] = mapped_column(String(16), index=True)
|
||||
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
|
||||
seq: Mapped[int] = mapped_column(BigInteger, default=0, index=True)
|
||||
content: Mapped[str] = mapped_column(UniversalText, default="")
|
||||
meta: Mapped[dict[str, Any]] = mapped_column("event_metadata", JSON, default_factory=dict)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Event timestamp",
|
||||
)
|
||||
@@ -0,0 +1,43 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import JSON, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, current_time
|
||||
|
||||
|
||||
class ThreadMeta(DataClassBase):
|
||||
"""Thread metadata table."""
|
||||
|
||||
__tablename__ = "threads_meta"
|
||||
__table_args__ = {"comment": "Thread metadata table."}
|
||||
|
||||
thread_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
|
||||
assistant_id: Mapped[str | None] = mapped_column(String(128), default=None, index=True)
|
||||
user_id: Mapped[str | None] = mapped_column(String(64), default=None, index=True)
|
||||
display_name: Mapped[str | None] = mapped_column(String(256), default=None)
|
||||
status: Mapped[str] = mapped_column(String(20), default="idle", index=True)
|
||||
|
||||
meta: Mapped[dict[str, Any]] = mapped_column("metadata_json", JSON, default_factory=dict)
|
||||
|
||||
created_time: Mapped[datetime] = mapped_column(
|
||||
"created_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
updated_time: Mapped[datetime | None] = mapped_column(
|
||||
"updated_at",
|
||||
TimeZone,
|
||||
init=False,
|
||||
default=None,
|
||||
onupdate=current_time,
|
||||
sort_order=999,
|
||||
comment="Updated at",
|
||||
)
|
||||
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import Boolean, Index, String, text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from store.persistence.base_model import DataClassBase, TimeZone, current_time
|
||||
|
||||
|
||||
class User(DataClassBase):
|
||||
"""User account table."""
|
||||
|
||||
__tablename__ = "users"
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"idx_users_oauth_identity",
|
||||
"oauth_provider",
|
||||
"oauth_id",
|
||||
unique=True,
|
||||
sqlite_where=text("oauth_provider IS NOT NULL AND oauth_id IS NOT NULL"),
|
||||
),
|
||||
{"comment": "User account table."},
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True)
|
||||
email: Mapped[str] = mapped_column(String(320), unique=True, nullable=False, index=True)
|
||||
system_role: Mapped[str] = mapped_column(String(16), default="user")
|
||||
|
||||
password_hash: Mapped[str | None] = mapped_column(String(128), default=None)
|
||||
oauth_provider: Mapped[str | None] = mapped_column(String(32), default=None)
|
||||
oauth_id: Mapped[str | None] = mapped_column(String(128), default=None)
|
||||
needs_setup: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
token_version: Mapped[int] = mapped_column(default=0)
|
||||
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
TimeZone,
|
||||
init=False,
|
||||
default_factory=current_time,
|
||||
sort_order=999,
|
||||
comment="Created at",
|
||||
)
|
||||
@@ -0,0 +1,3 @@
|
||||
from .timezone import get_timezone
|
||||
|
||||
__all__ = ["get_timezone"]
|
||||
@@ -0,0 +1,51 @@
|
||||
import zoneinfo
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from store.config.app_config import get_app_config
|
||||
|
||||
# IANA identifiers that map to UTC — see https://en.wikipedia.org/wiki/List_of_tz_database_time_zones
|
||||
_UTC_IDENTIFIERS = frozenset({"Etc/UCT", "Etc/Universal", "Etc/UTC", "Etc/Zulu", "UCT", "Universal", "UTC", "Zulu"})
|
||||
|
||||
|
||||
class TimeZone:
|
||||
def __init__(self) -> None:
|
||||
app_config = get_app_config()
|
||||
if app_config.timezone in _UTC_IDENTIFIERS:
|
||||
self.tz_info = UTC
|
||||
else:
|
||||
self.tz_info = zoneinfo.ZoneInfo(app_config.timezone)
|
||||
|
||||
def now(self) -> datetime:
|
||||
"""Return the current time in the configured timezone."""
|
||||
return datetime.now(self.tz_info)
|
||||
|
||||
def from_datetime(self, t: datetime) -> datetime:
|
||||
"""Convert a datetime to the configured timezone."""
|
||||
return t.astimezone(self.tz_info)
|
||||
|
||||
def from_str(self, t_str: str, format_str: str = "%Y-%m-%d %H:%M:%S") -> datetime:
|
||||
"""Parse a time string and attach the configured timezone."""
|
||||
return datetime.strptime(t_str, format_str).replace(tzinfo=self.tz_info)
|
||||
|
||||
@staticmethod
|
||||
def to_str(t: datetime, format_str: str = "%Y-%m-%d %H:%M:%S") -> str:
|
||||
"""Format a datetime to string."""
|
||||
return t.strftime(format_str)
|
||||
|
||||
@staticmethod
|
||||
def to_utc(t: datetime | int) -> datetime:
|
||||
"""Convert a datetime or Unix timestamp to UTC."""
|
||||
if isinstance(t, datetime):
|
||||
return t.astimezone(UTC)
|
||||
return datetime.fromtimestamp(t, tz=UTC)
|
||||
|
||||
|
||||
_timezone = None
|
||||
|
||||
|
||||
def get_timezone() -> TimeZone:
|
||||
"""Return the global TimeZone singleton (lazy-initialized)."""
|
||||
global _timezone
|
||||
if _timezone is None:
|
||||
_timezone = TimeZone()
|
||||
return _timezone
|
||||
@@ -6,6 +6,7 @@ readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"deerflow-harness",
|
||||
"deerflow-storage",
|
||||
"fastapi>=0.115.0",
|
||||
"httpx>=0.28.0",
|
||||
"python-multipart>=0.0.27",
|
||||
@@ -24,8 +25,8 @@ dependencies = [
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
postgres = ["deerflow-harness[postgres]"]
|
||||
discord = ["discord.py>=2.7.0"]
|
||||
postgres = ["deerflow-harness[postgres]", "deerflow-storage[postgres]"]
|
||||
mysql = ["deerflow-storage[mysql]"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
@@ -44,7 +45,8 @@ markers = [
|
||||
index-url = "https://pypi.org/simple"
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["packages/harness"]
|
||||
members = ["packages/harness", "packages/storage"]
|
||||
|
||||
[tool.uv.sources]
|
||||
deerflow-harness = { workspace = true }
|
||||
deerflow-storage = { workspace = true }
|
||||
|
||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import call_unwrapped, make_authed_test_app
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import FileResponse
|
||||
@@ -103,17 +102,3 @@ def test_get_artifact_download_true_forces_attachment_for_skill_archive(tmp_path
|
||||
assert response.status_code == 200
|
||||
assert response.text == "hello"
|
||||
assert response.headers.get("content-disposition", "").startswith("attachment;")
|
||||
|
||||
|
||||
def test_skill_archive_preview_rejects_oversized_member_before_decompression(tmp_path) -> None:
|
||||
skill_path = tmp_path / "sample.skill"
|
||||
payload = b"A" * (artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES + 1)
|
||||
with zipfile.ZipFile(skill_path, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zip_ref:
|
||||
zip_ref.writestr("SKILL.md", payload)
|
||||
|
||||
assert skill_path.stat().st_size < artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
artifacts_router._extract_file_from_skill_archive(skill_path, "SKILL.md")
|
||||
|
||||
assert exc_info.value.status_code == 413
|
||||
|
||||
@@ -5,26 +5,28 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
import app.gateway.auth.config as cfg
|
||||
from app.gateway.auth.config import AuthConfig
|
||||
|
||||
|
||||
def test_auth_config_defaults():
|
||||
config = cfg.AuthConfig(jwt_secret="test-secret-key-123")
|
||||
config = AuthConfig(jwt_secret="test-secret-key-123")
|
||||
assert config.token_expiry_days == 7
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_range():
|
||||
cfg.AuthConfig(jwt_secret="s", token_expiry_days=1)
|
||||
cfg.AuthConfig(jwt_secret="s", token_expiry_days=30)
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=1)
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=30)
|
||||
with pytest.raises(Exception):
|
||||
cfg.AuthConfig(jwt_secret="s", token_expiry_days=0)
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=0)
|
||||
with pytest.raises(Exception):
|
||||
cfg.AuthConfig(jwt_secret="s", token_expiry_days=31)
|
||||
AuthConfig(jwt_secret="s", token_expiry_days=31)
|
||||
|
||||
|
||||
def test_auth_config_from_env():
|
||||
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
@@ -34,57 +36,19 @@ def test_auth_config_from_env():
|
||||
cfg._auth_config = old
|
||||
|
||||
|
||||
def test_auth_config_missing_secret_generates_and_persists(tmp_path, caplog):
|
||||
def test_auth_config_missing_secret_generates_ephemeral(caplog):
|
||||
import logging
|
||||
|
||||
from deerflow.config.paths import Paths
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
secret_file = tmp_path / ".jwt_secret"
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)), caplog.at_level(logging.WARNING):
|
||||
with caplog.at_level(logging.WARNING):
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret
|
||||
assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
|
||||
assert secret_file.exists()
|
||||
assert secret_file.read_text().strip() == config.jwt_secret
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
|
||||
|
||||
def test_auth_config_reuses_persisted_secret(tmp_path):
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
persisted = "persisted-secret-from-file-min-32-chars!!"
|
||||
(tmp_path / ".jwt_secret").write_text(persisted, encoding="utf-8")
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)):
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret == persisted
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
|
||||
|
||||
def test_auth_config_empty_secret_file_generates_new(tmp_path):
|
||||
from deerflow.config.paths import Paths
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
(tmp_path / ".jwt_secret").write_text("", encoding="utf-8")
|
||||
try:
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
os.environ.pop("AUTH_JWT_SECRET", None)
|
||||
with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)):
|
||||
config = cfg.get_auth_config()
|
||||
assert config.jwt_secret
|
||||
assert len(config.jwt_secret) > 20
|
||||
assert (tmp_path / ".jwt_secret").read_text().strip() == config.jwt_secret
|
||||
finally:
|
||||
cfg._auth_config = old
|
||||
|
||||
@@ -761,7 +761,7 @@ class TestChannelManager:
|
||||
|
||||
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
|
||||
|
||||
async def _runs_wait(thread_id, assistant_id, *, input, config, context, multitask_strategy=None):
|
||||
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
|
||||
del assistant_id, context # unused in this test, kept for signature parity
|
||||
|
||||
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
|
||||
|
||||
@@ -94,12 +94,15 @@ class TestHarnessPackaging:
|
||||
"psycopg-pool>=3.3.0",
|
||||
]
|
||||
|
||||
def test_workspace_pyproject_forwards_postgres_extra_to_harness(self):
|
||||
def test_workspace_pyproject_forwards_postgres_extra_to_storage_packages(self):
|
||||
pyproject_path = Path(__file__).resolve().parents[1] / "pyproject.toml"
|
||||
data = tomllib.loads(pyproject_path.read_text())
|
||||
|
||||
optional_dependencies = data["project"]["optional-dependencies"]
|
||||
assert optional_dependencies["postgres"] == ["deerflow-harness[postgres]"]
|
||||
assert optional_dependencies["postgres"] == [
|
||||
"deerflow-harness[postgres]",
|
||||
"deerflow-storage[postgres]",
|
||||
]
|
||||
|
||||
def test_postgres_missing_dependency_messages_recommend_package_extra(self):
|
||||
assert "deerflow-harness[postgres]" in POSTGRES_INSTALL
|
||||
|
||||
@@ -158,88 +158,6 @@ class TestBuildPatchedMessagesPatching:
|
||||
assert patched[1].name == "bash"
|
||||
assert patched[1].status == "error"
|
||||
|
||||
def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self):
|
||||
middleware = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
HumanMessage(content="interruption"),
|
||||
_tool_msg("call_1", "bash"),
|
||||
]
|
||||
patched = middleware._build_patched_messages(msgs)
|
||||
assert patched is not None
|
||||
assert isinstance(patched[0], AIMessage)
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert patched[1].tool_call_id == "call_1"
|
||||
assert isinstance(patched[2], HumanMessage)
|
||||
|
||||
def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
|
||||
HumanMessage(content="interruption"),
|
||||
_tool_msg("call_2", "read"),
|
||||
_tool_msg("call_1", "bash"),
|
||||
]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert isinstance(patched[0], AIMessage)
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert isinstance(patched[2], ToolMessage)
|
||||
assert [patched[1].tool_call_id, patched[2].tool_call_id] == ["call_1", "call_2"]
|
||||
assert isinstance(patched[3], HumanMessage)
|
||||
|
||||
def test_valid_adjacent_tool_results_are_unchanged(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
_tool_msg("call_1", "bash"),
|
||||
HumanMessage(content="next"),
|
||||
]
|
||||
|
||||
assert mw._build_patched_messages(msgs) is None
|
||||
|
||||
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
HumanMessage(content="interruption"),
|
||||
_ai_with_tool_calls([_tc("read", "call_2")]),
|
||||
_tool_msg("call_1", "bash"),
|
||||
_tool_msg("call_2", "read"),
|
||||
]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert isinstance(patched[0], AIMessage)
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert patched[1].tool_call_id == "call_1"
|
||||
assert isinstance(patched[2], HumanMessage)
|
||||
assert isinstance(patched[3], AIMessage)
|
||||
assert isinstance(patched[4], ToolMessage)
|
||||
assert patched[4].tool_call_id == "call_2"
|
||||
|
||||
def test_orphan_tool_message_is_preserved_during_grouping(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
orphan = _tool_msg("orphan_call", "orphan")
|
||||
msgs = [
|
||||
_ai_with_tool_calls([_tc("bash", "call_1")]),
|
||||
orphan,
|
||||
HumanMessage(content="interruption"),
|
||||
_tool_msg("call_1", "bash"),
|
||||
]
|
||||
|
||||
patched = mw._build_patched_messages(msgs)
|
||||
|
||||
assert patched is not None
|
||||
assert isinstance(patched[0], AIMessage)
|
||||
assert isinstance(patched[1], ToolMessage)
|
||||
assert patched[1].tool_call_id == "call_1"
|
||||
assert orphan in patched
|
||||
assert patched.count(orphan) == 1
|
||||
|
||||
def test_invalid_tool_call_is_patched(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]
|
||||
|
||||
@@ -454,6 +454,7 @@ class TestAStream:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_tools_emits_tool_call_chunk(self):
|
||||
|
||||
tool_calls = [{"name": "fn", "args": {}, "id": "c1"}]
|
||||
with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
|
||||
mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls)
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import Column, MetaData, String, Table
|
||||
from sqlalchemy.dialects import mysql, postgresql
|
||||
from sqlalchemy.types import JSON
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from store.persistence.json_compat import json_match
|
||||
|
||||
|
||||
def _table():
|
||||
metadata = MetaData()
|
||||
return Table("t", metadata, Column("data", JSON), Column("id", String))
|
||||
|
||||
|
||||
def test_storage_json_match_compiles_sqlite() -> None:
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
table = _table()
|
||||
dialect = create_engine("sqlite://").dialect
|
||||
|
||||
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'null'")
|
||||
assert str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_type(t.data, '$.\"k\"') = 'true'")
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "= 'integer'" in int_sql
|
||||
assert "CAST" in int_sql
|
||||
|
||||
float_sql = str(json_match(table.c.data, "k", 3.14).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "IN ('integer', 'real')" in float_sql
|
||||
assert "REAL" in float_sql
|
||||
|
||||
|
||||
def test_storage_json_match_compiles_postgres() -> None:
|
||||
table = _table()
|
||||
dialect = postgresql.dialect()
|
||||
|
||||
assert str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("json_typeof(t.data -> 'k') = 'null'")
|
||||
assert str(json_match(table.c.data, "k", False).compile(dialect=dialect, compile_kwargs={"literal_binds": True})) == ("(json_typeof(t.data -> 'k') = 'boolean' AND (t.data ->> 'k') = 'false')")
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "CASE WHEN" in int_sql
|
||||
assert "BIGINT" in int_sql
|
||||
assert "'^-?[0-9]+$'" in int_sql
|
||||
|
||||
|
||||
def test_storage_json_match_compiles_mysql() -> None:
|
||||
table = _table()
|
||||
dialect = mysql.dialect()
|
||||
|
||||
null_sql = str(json_match(table.c.data, "k", None).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert null_sql == "JSON_TYPE(JSON_EXTRACT(t.data, '$.\"k\"')) = 'NULL'"
|
||||
|
||||
bool_sql = str(json_match(table.c.data, "k", True).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "JSON_TYPE(JSON_EXTRACT" in bool_sql
|
||||
assert "= 'BOOLEAN'" in bool_sql
|
||||
assert "= 'true'" in bool_sql
|
||||
|
||||
int_sql = str(json_match(table.c.data, "k", 42).compile(dialect=dialect, compile_kwargs={"literal_binds": True}))
|
||||
assert "= 'INTEGER'" in int_sql
|
||||
assert "SIGNED" in int_sql
|
||||
|
||||
|
||||
def test_storage_json_match_rejects_unsafe_keys_and_values() -> None:
|
||||
table = _table()
|
||||
|
||||
for bad_key in ["a.b", "bad;key", "with space", "", 42, None]:
|
||||
with pytest.raises(ValueError, match="JsonMatch key must match"):
|
||||
json_match(table.c.data, bad_key, "x") # type: ignore[arg-type]
|
||||
|
||||
for bad_value in [[], {}, object()]:
|
||||
with pytest.raises(TypeError, match="JsonMatch value must be"):
|
||||
json_match(table.c.data, "k", bad_value)
|
||||
|
||||
with pytest.raises(TypeError, match="out of signed 64-bit range"):
|
||||
json_match(table.c.data, "k", 2**63)
|
||||
@@ -0,0 +1,122 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from store.config.storage_config import StorageConfig
|
||||
from store.persistence.factory import _create_database_url, storage_config_from_database_config
|
||||
|
||||
|
||||
def test_database_sqlite_config_maps_to_storage_config(tmp_path):
|
||||
database = SimpleNamespace(
|
||||
backend="sqlite",
|
||||
sqlite_dir=str(tmp_path),
|
||||
echo_sql=True,
|
||||
pool_size=9,
|
||||
)
|
||||
|
||||
storage = storage_config_from_database_config(database)
|
||||
|
||||
assert storage == StorageConfig(
|
||||
driver="sqlite",
|
||||
sqlite_dir=str(tmp_path),
|
||||
echo_sql=True,
|
||||
pool_size=9,
|
||||
)
|
||||
assert storage.sqlite_storage_path == str(tmp_path / "deerflow.db")
|
||||
|
||||
|
||||
def test_database_memory_config_is_not_a_storage_backend():
|
||||
database = SimpleNamespace(backend="memory")
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported database backend"):
|
||||
storage_config_from_database_config(database)
|
||||
|
||||
|
||||
def test_database_postgres_config_preserves_url_and_pool_options():
|
||||
database = SimpleNamespace(
|
||||
backend="postgres",
|
||||
postgres_url="postgresql://user:pass@db.example:5544/deerflow",
|
||||
echo_sql=True,
|
||||
pool_size=11,
|
||||
)
|
||||
|
||||
storage = storage_config_from_database_config(database)
|
||||
url = _create_database_url(storage)
|
||||
|
||||
assert storage.driver == "postgres"
|
||||
assert storage.database_url == "postgresql://user:pass@db.example:5544/deerflow"
|
||||
assert storage.username == "user"
|
||||
assert storage.password == "pass"
|
||||
assert storage.host == "db.example"
|
||||
assert storage.port == 5544
|
||||
assert storage.db_name == "deerflow"
|
||||
assert storage.echo_sql is True
|
||||
assert storage.pool_size == 11
|
||||
assert url.drivername == "postgresql+asyncpg"
|
||||
assert url.database == "deerflow"
|
||||
|
||||
|
||||
def test_mysql_database_url_is_normalized_to_async_driver():
|
||||
storage = StorageConfig(
|
||||
driver="mysql",
|
||||
database_url="mysql://user:pass@db.example:3306/deerflow",
|
||||
)
|
||||
|
||||
url = _create_database_url(storage)
|
||||
|
||||
assert url.drivername == "mysql+aiomysql"
|
||||
assert url.database == "deerflow"
|
||||
|
||||
|
||||
def test_mysql_async_database_url_is_preserved():
|
||||
storage = StorageConfig(
|
||||
driver="mysql",
|
||||
database_url="mysql+asyncmy://user:pass@db.example:3306/deerflow",
|
||||
)
|
||||
|
||||
url = _create_database_url(storage)
|
||||
|
||||
assert url.drivername == "mysql+asyncmy"
|
||||
assert url.database == "deerflow"
|
||||
|
||||
|
||||
def test_database_postgres_requires_url():
|
||||
database = SimpleNamespace(backend="postgres", postgres_url="")
|
||||
|
||||
with pytest.raises(ValueError, match="database.postgres_url is required"):
|
||||
storage_config_from_database_config(database)
|
||||
|
||||
|
||||
def test_unsupported_database_backend_rejected():
|
||||
database = SimpleNamespace(backend="oracle")
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported database backend"):
|
||||
storage_config_from_database_config(database)
|
||||
|
||||
|
||||
def test_storage_models_import_without_config_file(tmp_path):
|
||||
env = os.environ.copy()
|
||||
env["DEER_FLOW_CONFIG_PATH"] = str(tmp_path / "missing-config.yaml")
|
||||
|
||||
result = subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
"-c",
|
||||
"from store.persistence.base_model import UniversalText, id_key; from store.repositories.models import RunEvent; print(UniversalText.__name__, RunEvent.__tablename__, id_key)",
|
||||
],
|
||||
check=False,
|
||||
capture_output=True,
|
||||
env=env,
|
||||
text=True,
|
||||
)
|
||||
|
||||
assert result.returncode == 0, result.stderr
|
||||
assert "UniversalText run_events" in result.stdout
|
||||
@@ -0,0 +1,58 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
from uuid import uuid4
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from sqlalchemy import inspect
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.repositories import UserCreate, build_user_repository
|
||||
|
||||
|
||||
def test_sqlite_persistence_from_database_config_creates_storage_tables(tmp_path):
|
||||
async def run() -> None:
|
||||
persistence = await create_persistence_from_database_config(
|
||||
SimpleNamespace(
|
||||
backend="sqlite",
|
||||
sqlite_dir=str(tmp_path),
|
||||
echo_sql=False,
|
||||
pool_size=5,
|
||||
)
|
||||
)
|
||||
assert persistence is not None
|
||||
try:
|
||||
await persistence.setup()
|
||||
|
||||
async with persistence.engine.connect() as conn:
|
||||
tables = await conn.run_sync(lambda sync_conn: set(inspect(sync_conn).get_table_names()))
|
||||
|
||||
assert {
|
||||
"users",
|
||||
"runs",
|
||||
"run_events",
|
||||
"threads_meta",
|
||||
"feedback",
|
||||
}.issubset(tables)
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_user_repository(session)
|
||||
user = await repo.create_user(
|
||||
UserCreate(
|
||||
id=str(uuid4()),
|
||||
email="storage-user@example.com",
|
||||
password_hash="hash",
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_user_repository(session)
|
||||
assert await repo.get_user_by_id(user.id) == user
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
asyncio.run(run())
|
||||
@@ -0,0 +1,395 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from store.persistence import create_persistence_from_database_config
|
||||
from store.repositories import (
|
||||
FeedbackCreate,
|
||||
InvalidMetadataFilterError,
|
||||
RunCreate,
|
||||
RunEventCreate,
|
||||
ThreadMetaCreate,
|
||||
build_feedback_repository,
|
||||
build_run_event_repository,
|
||||
build_run_repository,
|
||||
build_thread_meta_repository,
|
||||
)
|
||||
|
||||
|
||||
async def _make_persistence(tmp_path):
|
||||
persistence = await create_persistence_from_database_config(
|
||||
SimpleNamespace(
|
||||
backend="sqlite",
|
||||
sqlite_dir=str(tmp_path),
|
||||
echo_sql=False,
|
||||
pool_size=5,
|
||||
)
|
||||
)
|
||||
await persistence.setup()
|
||||
return persistence
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_run_repository_filters_and_aggregates(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
old = datetime.now(UTC) - timedelta(hours=1)
|
||||
newer = datetime.now(UTC)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_repository(session)
|
||||
await repo.create_run(
|
||||
RunCreate(
|
||||
run_id="run-old",
|
||||
thread_id="thread-1",
|
||||
user_id="alice",
|
||||
status="pending",
|
||||
model_name="model-a",
|
||||
metadata={"kind": "draft"},
|
||||
kwargs={"temperature": 0.2},
|
||||
created_time=old,
|
||||
)
|
||||
)
|
||||
await repo.create_run(
|
||||
RunCreate(
|
||||
run_id="run-new",
|
||||
thread_id="thread-1",
|
||||
user_id="bob",
|
||||
status="running",
|
||||
model_name="model-b",
|
||||
error="queued",
|
||||
created_time=newer,
|
||||
)
|
||||
)
|
||||
await repo.create_run(RunCreate(run_id="run-other", thread_id="thread-2", status="running"))
|
||||
await repo.update_run_completion(
|
||||
"run-old",
|
||||
status="success",
|
||||
total_input_tokens=7,
|
||||
total_output_tokens=3,
|
||||
total_tokens=10,
|
||||
llm_call_count=1,
|
||||
lead_agent_tokens=8,
|
||||
subagent_tokens=2,
|
||||
first_human_message="hello",
|
||||
last_ai_message="world",
|
||||
)
|
||||
await repo.update_run_completion(
|
||||
"run-new",
|
||||
status="error",
|
||||
total_tokens=5,
|
||||
middleware_tokens=5,
|
||||
error="failed",
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_repository(session)
|
||||
fetched = await repo.get_run("run-old")
|
||||
assert fetched is not None
|
||||
assert fetched.metadata == {"kind": "draft"}
|
||||
assert fetched.kwargs == {"temperature": 0.2}
|
||||
assert fetched.first_human_message == "hello"
|
||||
assert fetched.last_ai_message == "world"
|
||||
|
||||
all_thread_runs = await repo.list_runs_by_thread("thread-1")
|
||||
assert [run.run_id for run in all_thread_runs] == ["run-new", "run-old"]
|
||||
alice_runs = await repo.list_runs_by_thread("thread-1", user_id="alice")
|
||||
assert [run.run_id for run in alice_runs] == ["run-old"]
|
||||
|
||||
pending = await repo.list_pending(before=datetime.now(UTC).isoformat())
|
||||
assert [run.run_id for run in pending] == []
|
||||
|
||||
agg = await repo.aggregate_tokens_by_thread("thread-1")
|
||||
assert agg["total_tokens"] == 15
|
||||
assert agg["total_input_tokens"] == 7
|
||||
assert agg["total_output_tokens"] == 3
|
||||
assert agg["total_runs"] == 2
|
||||
assert agg["by_model"] == {
|
||||
"model-a": {"tokens": 10, "runs": 1},
|
||||
"model-b": {"tokens": 5, "runs": 1},
|
||||
}
|
||||
assert agg["by_caller"] == {"lead_agent": 8, "subagent": 2, "middleware": 5}
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_repository_search_update_delete(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
await repo.create_thread_meta(
|
||||
ThreadMetaCreate(
|
||||
thread_id="thread-1",
|
||||
assistant_id="agent-a",
|
||||
user_id="alice",
|
||||
display_name="Initial",
|
||||
status="idle",
|
||||
metadata={"topic": "finance", "region": "cn"},
|
||||
)
|
||||
)
|
||||
await repo.create_thread_meta(
|
||||
ThreadMetaCreate(
|
||||
thread_id="thread-2",
|
||||
assistant_id="agent-b",
|
||||
user_id="bob",
|
||||
status="running",
|
||||
metadata={"topic": "legal"},
|
||||
)
|
||||
)
|
||||
await repo.update_thread_meta(
|
||||
"thread-1",
|
||||
display_name="Updated",
|
||||
status="running",
|
||||
metadata={"topic": "finance", "region": "us"},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
fetched = await repo.get_thread_meta("thread-1")
|
||||
assert fetched is not None
|
||||
assert fetched.display_name == "Updated"
|
||||
assert fetched.status == "running"
|
||||
assert fetched.metadata == {"topic": "finance", "region": "us"}
|
||||
|
||||
by_metadata = await repo.search_threads(metadata={"topic": "finance"}, user_id="alice")
|
||||
assert [thread.thread_id for thread in by_metadata] == ["thread-1"]
|
||||
by_assistant = await repo.search_threads(assistant_id="agent-b")
|
||||
assert [thread.thread_id for thread in by_assistant] == ["thread-2"]
|
||||
|
||||
await repo.delete_thread("thread-1")
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
assert await repo.get_thread_meta("thread-1") is None
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_metadata_filters_are_type_safe(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-true", metadata={"value": True}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="bool-false", metadata={"value": False}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="int-one", metadata={"value": 1}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="null-value", metadata={"value": None}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="missing-value", metadata={"other": "x"}))
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": True})] == ["bool-true"]
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": False})] == ["bool-false"]
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": 1})] == ["int-one"]
|
||||
assert [row.thread_id for row in await repo.search_threads(metadata={"value": None})] == ["null-value"]
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_metadata_filters_paginate_after_sql_match(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
for index in range(30):
|
||||
metadata = {"target": "yes"} if index % 3 == 0 else {"target": "no"}
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id=f"thread-{index:02d}", metadata=metadata))
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
first_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=0)
|
||||
second_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=3)
|
||||
last_page = await repo.search_threads(metadata={"target": "yes"}, limit=3, offset=9)
|
||||
|
||||
assert len(first_page) == 3
|
||||
assert len(second_page) == 3
|
||||
assert len(last_page) == 1
|
||||
assert {row.thread_id for row in first_page}.isdisjoint({row.thread_id for row in second_page})
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_thread_meta_metadata_filter_rejects_invalid_entries(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-1", metadata={"env": "prod"}))
|
||||
await repo.create_thread_meta(ThreadMetaCreate(thread_id="thread-2", metadata={"env": "staging"}))
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_thread_meta_repository(session)
|
||||
partial = await repo.search_threads(metadata={"env": "prod", "bad;key": "ignored"})
|
||||
assert [row.thread_id for row in partial] == ["thread-1"]
|
||||
|
||||
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
|
||||
await repo.search_threads(metadata={"bad;key": "x"})
|
||||
with pytest.raises(InvalidMetadataFilterError, match="rejected"):
|
||||
await repo.search_threads(metadata={"env": ["prod", "staging"]})
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_feedback_repository_lists_and_deletes(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_feedback_repository(session)
|
||||
first = await repo.create_feedback(
|
||||
FeedbackCreate(
|
||||
feedback_id="fb-1",
|
||||
run_id="run-1",
|
||||
thread_id="thread-1",
|
||||
rating=1,
|
||||
user_id="alice",
|
||||
message_id="msg-1",
|
||||
comment="good",
|
||||
)
|
||||
)
|
||||
second = await repo.create_feedback(
|
||||
FeedbackCreate(
|
||||
feedback_id="fb-2",
|
||||
run_id="run-1",
|
||||
thread_id="thread-1",
|
||||
rating=-1,
|
||||
user_id="bob",
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_feedback_repository(session)
|
||||
assert await repo.get_feedback(first.feedback_id) == first
|
||||
assert [item.feedback_id for item in await repo.list_feedback_by_run("run-1")] == [
|
||||
second.feedback_id,
|
||||
first.feedback_id,
|
||||
]
|
||||
assert {item.feedback_id for item in await repo.list_feedback_by_thread("thread-1")} == {
|
||||
"fb-1",
|
||||
"fb-2",
|
||||
}
|
||||
assert await repo.delete_feedback("fb-1") is True
|
||||
assert await repo.delete_feedback("missing") is False
|
||||
with pytest.raises(ValueError, match="rating must be"):
|
||||
await repo.create_feedback(
|
||||
FeedbackCreate(
|
||||
feedback_id="fb-bad",
|
||||
run_id="run-1",
|
||||
thread_id="thread-1",
|
||||
rating=0,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_feedback_repository(session)
|
||||
assert await repo.get_feedback("fb-1") is None
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_storage_run_event_repository_sequences_paginates_and_deletes(tmp_path):
|
||||
persistence = await _make_persistence(tmp_path)
|
||||
try:
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
rows = await repo.append_batch(
|
||||
[
|
||||
RunEventCreate(
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
user_id="alice",
|
||||
event_type="message",
|
||||
category="message",
|
||||
content={"role": "user", "content": "hello"},
|
||||
metadata={"source": "input"},
|
||||
),
|
||||
RunEventCreate(
|
||||
thread_id="thread-1",
|
||||
run_id="run-1",
|
||||
event_type="tool",
|
||||
category="debug",
|
||||
content="tool-call",
|
||||
),
|
||||
RunEventCreate(
|
||||
thread_id="thread-1",
|
||||
run_id="run-2",
|
||||
event_type="message",
|
||||
category="message",
|
||||
content="second",
|
||||
),
|
||||
RunEventCreate(
|
||||
thread_id="thread-2",
|
||||
run_id="run-3",
|
||||
event_type="message",
|
||||
category="message",
|
||||
content="other-thread",
|
||||
),
|
||||
]
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
assert [row.thread_id for row in rows] == ["thread-1", "thread-1", "thread-1", "thread-2"]
|
||||
assert [row.seq for row in rows] == sorted(row.seq for row in rows)
|
||||
assert rows[1].seq == rows[0].seq + 1
|
||||
assert rows[2].seq == rows[1].seq + 1
|
||||
assert rows[0].content == {"role": "user", "content": "hello"}
|
||||
assert rows[0].metadata == {"source": "input", "content_is_json": True}
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
messages = await repo.list_messages("thread-1", limit=2)
|
||||
assert [event.seq for event in messages] == [rows[0].seq, rows[2].seq]
|
||||
assert await repo.count_messages("thread-1") == 2
|
||||
|
||||
after = await repo.list_messages_by_run("thread-1", "run-1", after_seq=0, limit=5)
|
||||
assert [event.seq for event in after] == [rows[0].seq]
|
||||
before = await repo.list_messages("thread-1", before_seq=rows[2].seq, limit=5)
|
||||
assert [event.seq for event in before] == [rows[0].seq]
|
||||
|
||||
events = await repo.list_events("thread-1", "run-1", event_types=["tool"])
|
||||
assert [event.content for event in events] == ["tool-call"]
|
||||
|
||||
assert await repo.delete_by_run("thread-1", "run-1") == 2
|
||||
assert await repo.delete_by_thread("thread-2") == 1
|
||||
await session.commit()
|
||||
|
||||
async with persistence.session_factory() as session:
|
||||
repo = build_run_event_repository(session)
|
||||
remaining = await repo.list_events("thread-1", "run-2")
|
||||
assert [event.seq for event in remaining] == [rows[2].seq]
|
||||
assert await repo.count_messages("thread-2") == 0
|
||||
|
||||
later = await repo.append_batch(
|
||||
[
|
||||
RunEventCreate(
|
||||
thread_id="thread-1",
|
||||
run_id="run-4",
|
||||
event_type="message",
|
||||
category="message",
|
||||
content="after-delete",
|
||||
)
|
||||
]
|
||||
)
|
||||
assert later[0].seq > rows[2].seq
|
||||
finally:
|
||||
await persistence.aclose()
|
||||
@@ -0,0 +1,177 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("DEER_FLOW_CONFIG_PATH", str(Path(__file__).resolve().parents[2] / "config.example.yaml"))
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from store.repositories import UserCreate, UserNotFoundError, build_user_repository
|
||||
from store.repositories.models import User as UserModel
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def _session_factory(tmp_path) -> AsyncGenerator[async_sessionmaker[AsyncSession]]:
|
||||
db_path = tmp_path / "storage-users.db"
|
||||
engine = create_async_engine(f"sqlite+aiosqlite:///{db_path}")
|
||||
async with engine.begin() as conn:
|
||||
await conn.run_sync(UserModel.metadata.create_all)
|
||||
|
||||
try:
|
||||
yield async_sessionmaker(engine, expire_on_commit=False)
|
||||
finally:
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
async def _create_user(
|
||||
session_factory: async_sessionmaker[AsyncSession],
|
||||
*,
|
||||
email: str = "user@example.com",
|
||||
system_role: str = "user",
|
||||
oauth_provider: str | None = None,
|
||||
oauth_id: str | None = None,
|
||||
):
|
||||
async with session_factory() as session:
|
||||
repo = build_user_repository(session)
|
||||
user = await repo.create_user(
|
||||
UserCreate(
|
||||
id=str(uuid4()),
|
||||
email=email,
|
||||
password_hash="hash",
|
||||
system_role=system_role, # type: ignore[arg-type]
|
||||
oauth_provider=oauth_provider,
|
||||
oauth_id=oauth_id,
|
||||
)
|
||||
)
|
||||
await session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def test_create_and_get_user_by_id_and_email(tmp_path):
|
||||
async def run() -> None:
|
||||
async with _session_factory(tmp_path) as session_factory:
|
||||
created = await _create_user(session_factory)
|
||||
|
||||
async with session_factory() as session:
|
||||
repo = build_user_repository(session)
|
||||
|
||||
by_id = await repo.get_user_by_id(created.id)
|
||||
by_email = await repo.get_user_by_email(created.email)
|
||||
|
||||
assert by_id == created
|
||||
assert by_email == created
|
||||
assert created.system_role == "user"
|
||||
assert created.needs_setup is False
|
||||
assert created.token_version == 0
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_duplicate_email_raises_value_error(tmp_path):
|
||||
async def run() -> None:
|
||||
async with _session_factory(tmp_path) as session_factory:
|
||||
await _create_user(session_factory, email="dupe@example.com")
|
||||
|
||||
async with session_factory() as session:
|
||||
repo = build_user_repository(session)
|
||||
with pytest.raises(ValueError, match="Email already registered"):
|
||||
await repo.create_user(
|
||||
UserCreate(
|
||||
id=str(uuid4()),
|
||||
email="dupe@example.com",
|
||||
password_hash="hash",
|
||||
)
|
||||
)
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_oauth_lookup_and_plain_users_without_oauth(tmp_path):
|
||||
async def run() -> None:
|
||||
async with _session_factory(tmp_path) as session_factory:
|
||||
await _create_user(session_factory, email="local-1@example.com")
|
||||
await _create_user(session_factory, email="local-2@example.com")
|
||||
oauth_user = await _create_user(
|
||||
session_factory,
|
||||
email="oauth@example.com",
|
||||
oauth_provider="github",
|
||||
oauth_id="gh-123",
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
repo = build_user_repository(session)
|
||||
|
||||
assert await repo.count_users() == 3
|
||||
assert await repo.get_user_by_oauth("github", "gh-123") == oauth_user
|
||||
assert await repo.get_user_by_oauth("github", "missing") is None
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_count_admins_and_get_first_admin(tmp_path):
|
||||
async def run() -> None:
|
||||
async with _session_factory(tmp_path) as session_factory:
|
||||
await _create_user(session_factory, email="user@example.com")
|
||||
admin = await _create_user(
|
||||
session_factory,
|
||||
email="admin@example.com",
|
||||
system_role="admin",
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
repo = build_user_repository(session)
|
||||
|
||||
assert await repo.count_users() == 2
|
||||
assert await repo.count_admin_users() == 1
|
||||
assert await repo.get_first_admin() == admin
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_update_user_round_trips_token_version_and_setup_state(tmp_path):
|
||||
async def run() -> None:
|
||||
async with _session_factory(tmp_path) as session_factory:
|
||||
created = await _create_user(session_factory)
|
||||
updated = created.model_copy(
|
||||
update={
|
||||
"email": "renamed@example.com",
|
||||
"token_version": 4,
|
||||
"needs_setup": True,
|
||||
}
|
||||
)
|
||||
|
||||
async with session_factory() as session:
|
||||
repo = build_user_repository(session)
|
||||
saved = await repo.update_user(updated)
|
||||
await session.commit()
|
||||
|
||||
async with session_factory() as session:
|
||||
repo = build_user_repository(session)
|
||||
fetched = await repo.get_user_by_id(created.id)
|
||||
|
||||
assert saved.email == "renamed@example.com"
|
||||
assert fetched == updated
|
||||
|
||||
asyncio.run(run())
|
||||
|
||||
|
||||
def test_update_missing_user_raises(tmp_path):
|
||||
async def run() -> None:
|
||||
async with _session_factory(tmp_path) as session_factory:
|
||||
missing = UserCreate(id=str(uuid4()), email="missing@example.com")
|
||||
|
||||
async with session_factory() as session:
|
||||
repo = build_user_repository(session)
|
||||
created_shape = await repo.create_user(missing)
|
||||
await session.rollback()
|
||||
|
||||
with pytest.raises(UserNotFoundError):
|
||||
await repo.update_user(created_shape)
|
||||
|
||||
asyncio.run(run())
|
||||
@@ -56,8 +56,7 @@ def _middleware(
|
||||
preserve_recent_skill_tokens_per_skill: int = 0,
|
||||
) -> DeerFlowSummarizationMiddleware:
|
||||
model = MagicMock()
|
||||
model.invoke.return_value = AIMessage(content="compressed summary")
|
||||
model.with_config.return_value.invoke.return_value = AIMessage(content="compressed summary")
|
||||
model.invoke.return_value = SimpleNamespace(text="compressed summary")
|
||||
return DeerFlowSummarizationMiddleware(
|
||||
model=model,
|
||||
trigger=trigger,
|
||||
@@ -643,69 +642,6 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
|
||||
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Issue #2804: summary text must not leak to the frontend via streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_new_messages_sets_hide_from_ui() -> None:
|
||||
"""The summary HumanMessage must carry hide_from_ui so the frontend filters it."""
|
||||
middleware = _middleware()
|
||||
messages = middleware._build_new_messages("test summary")
|
||||
|
||||
assert len(messages) == 1
|
||||
msg = messages[0]
|
||||
assert msg.name == "summary"
|
||||
assert msg.additional_kwargs.get("hide_from_ui") is True
|
||||
assert "test summary" in msg.content
|
||||
|
||||
|
||||
def test_create_summary_suppresses_callbacks() -> None:
|
||||
"""_create_summary must bind callbacks=[] on the model AND pass callbacks=[]
|
||||
in the invoke config to suppress inherited LangGraph stream callbacks."""
|
||||
middleware = _middleware()
|
||||
|
||||
middleware._create_summary(_messages())
|
||||
|
||||
middleware.model.with_config.assert_called_once_with(callbacks=[])
|
||||
bound = middleware.model.with_config.return_value
|
||||
bound.invoke.assert_called_once()
|
||||
call_config = bound.invoke.call_args.kwargs.get("config") or bound.invoke.call_args[1].get("config")
|
||||
assert call_config is not None
|
||||
assert call_config.get("callbacks") == []
|
||||
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_acreate_summary_suppresses_callbacks() -> None:
|
||||
"""_acreate_summary must bind callbacks=[] on the model AND pass callbacks=[]
|
||||
in the ainvoke config to suppress inherited LangGraph stream callbacks."""
|
||||
middleware = _middleware()
|
||||
middleware.model.with_config.return_value.ainvoke = mock.AsyncMock(return_value=AIMessage(content="async summary"))
|
||||
|
||||
await middleware._acreate_summary(_messages())
|
||||
|
||||
middleware.model.with_config.assert_called_once_with(callbacks=[])
|
||||
bound = middleware.model.with_config.return_value
|
||||
bound.ainvoke.assert_called_once()
|
||||
call_config = bound.ainvoke.call_args.kwargs.get("config") or bound.ainvoke.call_args[1].get("config")
|
||||
assert call_config is not None
|
||||
assert call_config.get("callbacks") == []
|
||||
assert call_config.get("metadata", {}).get("lc_source") == "summarization"
|
||||
|
||||
|
||||
def test_before_model_summary_message_has_hide_from_ui() -> None:
|
||||
"""End-to-end: the emitted state update contains a summary message with hide_from_ui."""
|
||||
middleware = _middleware()
|
||||
|
||||
result = middleware.before_model({"messages": _messages()}, _runtime())
|
||||
|
||||
emitted = result["messages"]
|
||||
summary_msg = emitted[1]
|
||||
assert summary_msg.name == "summary"
|
||||
assert summary_msg.additional_kwargs.get("hide_from_ui") is True
|
||||
|
||||
|
||||
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
queue = MagicMock()
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
|
||||
@@ -723,17 +659,3 @@ def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatc
|
||||
|
||||
queue.add_nowait.assert_called_once()
|
||||
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
|
||||
|
||||
|
||||
def test_extract_summary_text_normalizes_list_content_blocks() -> None:
|
||||
"""AIMessage.content can be a list of content blocks; _extract_summary_text
|
||||
must normalize to plain text via the .text property instead of producing
|
||||
a Python repr like [{'type': 'text', 'text': 'summary'}]."""
|
||||
middleware = _middleware()
|
||||
|
||||
response = AIMessage(content=[{"type": "text", "text": "A summary of the chat."}])
|
||||
assert middleware._extract_summary_text(response) == "A summary of the chat."
|
||||
|
||||
# Plain string content still works
|
||||
response_str = AIMessage(content="Plain summary")
|
||||
assert middleware._extract_summary_text(response_str) == "Plain summary"
|
||||
|
||||
@@ -1214,11 +1214,12 @@ def test_terminal_event_usage_none_when_no_records(monkeypatch):
|
||||
assert completed[0]["usage"] is None
|
||||
|
||||
|
||||
def test_subagent_usage_cache_is_skipped_when_config_file_is_missing(monkeypatch):
|
||||
@pytest.mark.parametrize("error", [FileNotFoundError("missing config"), ValueError("invalid config")])
|
||||
def test_subagent_usage_cache_is_skipped_when_default_config_cannot_load(monkeypatch, error):
|
||||
monkeypatch.setattr(
|
||||
task_tool_module,
|
||||
"get_app_config",
|
||||
MagicMock(side_effect=FileNotFoundError("missing config")),
|
||||
MagicMock(side_effect=error),
|
||||
)
|
||||
|
||||
assert task_tool_module._token_usage_cache_enabled(None) is False
|
||||
|
||||
@@ -93,7 +93,7 @@ class TestTitleMiddlewareCoreLogic:
|
||||
assert middleware._should_generate_title(state) is False
|
||||
|
||||
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=12, model_name=None)
|
||||
_set_test_title_config(max_chars=12)
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
||||
|
||||
@@ -1,19 +1,14 @@
|
||||
"""Tests for TodoMiddleware context-loss detection."""
|
||||
|
||||
import asyncio
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from deerflow.agents.middlewares.todo_middleware import (
|
||||
TodoMiddleware,
|
||||
_completion_reminder_count,
|
||||
_format_todos,
|
||||
_has_tool_call_intent_or_error,
|
||||
_reminder_in_messages,
|
||||
_todos_in_messages,
|
||||
)
|
||||
@@ -27,35 +22,9 @@ def _reminder_msg():
|
||||
return HumanMessage(name="todo_reminder", content="reminder")
|
||||
|
||||
|
||||
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
|
||||
_seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)
|
||||
|
||||
@property
|
||||
def seen_messages(self) -> list[list[Any]]:
|
||||
return self._seen_messages
|
||||
|
||||
def bind_tools(self, tools, *, tool_choice=None, **kwargs):
|
||||
return self
|
||||
|
||||
def _generate(self, messages, stop=None, run_manager=None, **kwargs):
|
||||
self._seen_messages.append(list(messages))
|
||||
return super()._generate(
|
||||
messages,
|
||||
stop=stop,
|
||||
run_manager=run_manager,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _make_runtime():
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "test-thread", "run_id": "test-run"}
|
||||
return runtime
|
||||
|
||||
|
||||
def _make_runtime_for(thread_id: str, run_id: str):
|
||||
runtime = _make_runtime()
|
||||
runtime.context = {"thread_id": thread_id, "run_id": run_id}
|
||||
runtime.context = {"thread_id": "test-thread"}
|
||||
return runtime
|
||||
|
||||
|
||||
@@ -192,62 +161,10 @@ def _completion_reminder_msg():
|
||||
return HumanMessage(name="todo_completion_reminder", content="finish your todos")
|
||||
|
||||
|
||||
def _todo_completion_reminders(messages):
|
||||
reminders = []
|
||||
for message in messages:
|
||||
if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder":
|
||||
reminders.append(message)
|
||||
return reminders
|
||||
|
||||
|
||||
def _ai_no_tool_calls():
|
||||
return AIMessage(content="I'm done!")
|
||||
|
||||
|
||||
def _ai_with_invalid_tool_calls():
|
||||
return AIMessage(
|
||||
content="",
|
||||
tool_calls=[],
|
||||
invalid_tool_calls=[
|
||||
{
|
||||
"type": "invalid_tool_call",
|
||||
"id": "write_file:36",
|
||||
"name": "write_file",
|
||||
"args": "{invalid",
|
||||
"error": "Failed to parse tool arguments",
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def _ai_with_raw_provider_tool_calls():
|
||||
return AIMessage(
|
||||
content="",
|
||||
tool_calls=[],
|
||||
invalid_tool_calls=[],
|
||||
additional_kwargs={
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "raw-tool-call",
|
||||
"type": "function",
|
||||
"function": {"name": "write_file", "arguments": '{"path":"report.md"}'},
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _ai_with_legacy_function_call():
|
||||
return AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}},
|
||||
)
|
||||
|
||||
|
||||
def _ai_with_tool_finish_reason():
|
||||
return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"})
|
||||
|
||||
|
||||
def _incomplete_todos():
|
||||
return [
|
||||
{"status": "completed", "content": "Step 1"},
|
||||
@@ -277,36 +194,6 @@ class TestCompletionReminderCount:
|
||||
assert _completion_reminder_count(msgs) == 1
|
||||
|
||||
|
||||
class TestToolCallIntentOrError:
|
||||
def test_false_for_plain_final_answer(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False
|
||||
|
||||
def test_true_for_structured_tool_calls(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True
|
||||
|
||||
def test_true_for_invalid_tool_calls(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True
|
||||
|
||||
def test_true_for_raw_provider_tool_calls(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True
|
||||
|
||||
def test_true_for_legacy_function_call(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True
|
||||
|
||||
def test_true_for_tool_finish_reason(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True
|
||||
|
||||
def test_langchain_ai_message_tool_fields_are_explicitly_handled(self):
|
||||
# Sentinel for LangChain compatibility: if future AIMessage versions add
|
||||
# new top-level tool/function-call fields, this test should fail. When
|
||||
# it does, update `_has_tool_call_intent_or_error()` so the completion
|
||||
# reminder guard explicitly decides whether each new field means "not a
|
||||
# clean final answer"; the helper has a matching comment pointing back
|
||||
# to this sentinel.
|
||||
tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())}
|
||||
assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"}
|
||||
|
||||
|
||||
class TestAfterModel:
|
||||
def test_returns_none_when_agent_still_using_tools(self):
|
||||
mw = TodoMiddleware()
|
||||
@@ -348,299 +235,68 @@ class TestAfterModel:
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_queues_reminder_and_jumps_to_model_when_incomplete(self):
|
||||
def test_injects_reminder_and_jumps_to_model_when_incomplete(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = mw.after_model(state, runtime)
|
||||
result = mw.after_model(state, _make_runtime())
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
assert "messages" not in result
|
||||
|
||||
request = MagicMock()
|
||||
request.runtime = runtime
|
||||
request.messages = state["messages"]
|
||||
request.override.return_value = "patched-request"
|
||||
handler = MagicMock(return_value="response")
|
||||
|
||||
assert mw.wrap_model_call(request, handler) == "response"
|
||||
request.override.assert_called_once()
|
||||
reminder = request.override.call_args.kwargs["messages"][-1]
|
||||
assert len(result["messages"]) == 1
|
||||
reminder = result["messages"][0]
|
||||
assert isinstance(reminder, HumanMessage)
|
||||
assert reminder.name == "todo_completion_reminder"
|
||||
assert reminder.additional_kwargs["hide_from_ui"] is True
|
||||
assert "Step 2" in reminder.content
|
||||
assert "Step 3" in reminder.content
|
||||
handler.assert_called_once_with("patched-request")
|
||||
|
||||
def test_reminder_lists_only_incomplete_items(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = mw.after_model(state, runtime)
|
||||
assert result is not None
|
||||
|
||||
request = MagicMock()
|
||||
request.runtime = runtime
|
||||
request.messages = state["messages"]
|
||||
request.override.return_value = "patched-request"
|
||||
mw.wrap_model_call(request, MagicMock(return_value="response"))
|
||||
content = request.override.call_args.kwargs["messages"][-1].content
|
||||
result = mw.after_model(state, _make_runtime())
|
||||
content = result["messages"][0].content
|
||||
assert "Step 1" not in content # completed — should not appear
|
||||
assert "Step 2" in content
|
||||
assert "Step 3" in content
|
||||
|
||||
def test_allows_exit_after_max_reminders(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [
|
||||
_completion_reminder_msg(),
|
||||
_completion_reminder_msg(),
|
||||
_ai_no_tool_calls(),
|
||||
],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, runtime) is not None
|
||||
assert mw.after_model(state, runtime) is not None
|
||||
assert mw.after_model(state, runtime) is None
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_still_sends_reminder_before_cap(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [
|
||||
_completion_reminder_msg(), # 1 reminder so far
|
||||
_ai_no_tool_calls(),
|
||||
],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, runtime) is not None
|
||||
result = mw.after_model(state, runtime)
|
||||
result = mw.after_model(state, _make_runtime())
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
|
||||
def test_does_not_trigger_for_invalid_tool_calls(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_with_invalid_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_does_not_trigger_for_raw_provider_tool_calls(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_with_raw_provider_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_does_not_trigger_for_legacy_function_call(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_with_legacy_function_call()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_does_not_trigger_for_tool_finish_reason(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_with_tool_finish_reason()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
|
||||
class TestAafterModel:
|
||||
def test_delegates_to_sync(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = asyncio.run(mw.aafter_model(state, runtime))
|
||||
result = asyncio.run(mw.aafter_model(state, _make_runtime()))
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
assert "messages" not in result
|
||||
|
||||
|
||||
class TestWrapModelCall:
|
||||
def test_no_pending_reminder_passthrough(self):
|
||||
mw = TodoMiddleware()
|
||||
request = MagicMock()
|
||||
request.runtime = _make_runtime()
|
||||
request.messages = [HumanMessage(content="hi")]
|
||||
handler = MagicMock(return_value="response")
|
||||
|
||||
assert mw.wrap_model_call(request, handler) == "response"
|
||||
request.override.assert_not_called()
|
||||
handler.assert_called_once_with(request)
|
||||
|
||||
def test_pending_reminder_is_injected_once(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
mw.after_model(state, runtime)
|
||||
|
||||
request = MagicMock()
|
||||
request.runtime = runtime
|
||||
request.messages = state["messages"]
|
||||
request.override.return_value = "patched-request"
|
||||
handler = MagicMock(return_value="response")
|
||||
|
||||
assert mw.wrap_model_call(request, handler) == "response"
|
||||
injected_messages = request.override.call_args.kwargs["messages"]
|
||||
assert injected_messages[-1].name == "todo_completion_reminder"
|
||||
|
||||
request.override.reset_mock()
|
||||
handler.reset_mock()
|
||||
handler.return_value = "second-response"
|
||||
assert mw.wrap_model_call(request, handler) == "second-response"
|
||||
request.override.assert_not_called()
|
||||
handler.assert_called_once_with(request)
|
||||
|
||||
|
||||
class TestTodoMiddlewareAgentGraphIntegration:
|
||||
def test_completion_reminder_is_transient_in_real_agent_graph(self):
|
||||
mw = TodoMiddleware()
|
||||
model = _CapturingFakeMessagesListChatModel(
|
||||
responses=[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "write_todos",
|
||||
"id": "todos-1",
|
||||
"args": {
|
||||
"todos": [
|
||||
{"content": "Step 1", "status": "completed"},
|
||||
{"content": "Step 2", "status": "pending"},
|
||||
]
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
AIMessage(content="premature final 1"),
|
||||
AIMessage(content="premature final 2"),
|
||||
AIMessage(content="premature final 3"),
|
||||
],
|
||||
)
|
||||
graph = create_agent(model=model, tools=[], middleware=[mw])
|
||||
|
||||
result = graph.invoke(
|
||||
{"messages": [("user", "finish all todos")]},
|
||||
context={"thread_id": "integration-thread", "run_id": "integration-run"},
|
||||
)
|
||||
|
||||
assert len(model.seen_messages) == 4
|
||||
reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages]
|
||||
assert reminders_by_call[0] == []
|
||||
assert reminders_by_call[1] == []
|
||||
assert len(reminders_by_call[2]) == 1
|
||||
assert len(reminders_by_call[3]) == 1
|
||||
assert "Step 1" not in reminders_by_call[2][0].content
|
||||
assert "Step 2" in reminders_by_call[2][0].content
|
||||
|
||||
persisted_reminders = _todo_completion_reminders(result["messages"])
|
||||
assert persisted_reminders == []
|
||||
assert result["messages"][-1].content == "premature final 3"
|
||||
assert result["todos"] == [
|
||||
{"content": "Step 1", "status": "completed"},
|
||||
{"content": "Step 2", "status": "pending"},
|
||||
]
|
||||
assert mw._pending_completion_reminders == {}
|
||||
assert mw._completion_reminder_counts == {}
|
||||
|
||||
|
||||
class TestRunScopedReminderCleanup:
|
||||
def test_before_agent_clears_stale_count_without_pending_reminder(self):
|
||||
mw = TodoMiddleware()
|
||||
stale_runtime = _make_runtime()
|
||||
stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"}
|
||||
current_runtime = _make_runtime()
|
||||
current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"}
|
||||
other_thread_runtime = _make_runtime()
|
||||
other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"}
|
||||
|
||||
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
||||
assert mw.after_model(state, stale_runtime) is not None
|
||||
assert mw.after_model(state, other_thread_runtime) is not None
|
||||
|
||||
# Simulate a model call that drained the pending message, followed by an
|
||||
# abnormal run end where after_agent did not clear the reminder count.
|
||||
assert mw._drain_completion_reminders(stale_runtime)
|
||||
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1
|
||||
|
||||
mw.before_agent({}, current_runtime)
|
||||
|
||||
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
|
||||
assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1
|
||||
|
||||
def test_size_guard_prunes_oldest_count_only_reminder_state(self):
|
||||
mw = TodoMiddleware()
|
||||
mw._MAX_COMPLETION_REMINDER_KEYS = 2
|
||||
first_runtime = _make_runtime_for("thread-a", "run-a")
|
||||
second_runtime = _make_runtime_for("thread-b", "run-b")
|
||||
third_runtime = _make_runtime_for("thread-c", "run-c")
|
||||
|
||||
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
||||
assert mw.after_model(state, first_runtime) is not None
|
||||
|
||||
# Simulate the normal model request path: pending reminder is consumed,
|
||||
# but the run count remains until after_agent() or stale cleanup.
|
||||
assert mw._drain_completion_reminders(first_runtime)
|
||||
assert mw._completion_reminder_count_for_runtime(first_runtime) == 1
|
||||
|
||||
assert mw.after_model(state, second_runtime) is not None
|
||||
assert mw.after_model(state, third_runtime) is not None
|
||||
|
||||
assert mw._completion_reminder_count_for_runtime(first_runtime) == 0
|
||||
assert mw._completion_reminder_count_for_runtime(second_runtime) == 1
|
||||
assert mw._completion_reminder_count_for_runtime(third_runtime) == 1
|
||||
assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order
|
||||
|
||||
def test_size_guard_prunes_pending_and_count_state_together(self):
|
||||
mw = TodoMiddleware()
|
||||
mw._MAX_COMPLETION_REMINDER_KEYS = 1
|
||||
stale_runtime = _make_runtime_for("thread-a", "run-a")
|
||||
current_runtime = _make_runtime_for("thread-b", "run-b")
|
||||
|
||||
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
||||
assert mw.after_model(state, stale_runtime) is not None
|
||||
assert mw.after_model(state, current_runtime) is not None
|
||||
|
||||
assert mw._drain_completion_reminders(stale_runtime) == []
|
||||
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
|
||||
assert mw._completion_reminder_count_for_runtime(current_runtime) == 1
|
||||
|
||||
|
||||
class TestAwrapModelCall:
|
||||
def test_async_pending_reminder_is_injected(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
mw.after_model(state, runtime)
|
||||
|
||||
request = MagicMock()
|
||||
request.runtime = runtime
|
||||
request.messages = state["messages"]
|
||||
request.override.return_value = "patched-request"
|
||||
handler = AsyncMock(return_value="response")
|
||||
|
||||
result = asyncio.run(mw.awrap_model_call(request, handler))
|
||||
assert result == "response"
|
||||
injected_messages = request.override.call_args.kwargs["messages"]
|
||||
assert injected_messages[-1].name == "todo_completion_reminder"
|
||||
handler.assert_awaited_once_with("patched-request")
|
||||
assert result["messages"][0].name == "todo_completion_reminder"
|
||||
|
||||
Generated
+93
-18
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 2
|
||||
revision = 3
|
||||
requires-python = ">=3.12"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.14' and sys_platform == 'win32'",
|
||||
@@ -14,6 +14,7 @@ resolution-markers = [
|
||||
members = [
|
||||
"deer-flow",
|
||||
"deerflow-harness",
|
||||
"deerflow-storage",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -136,6 +137,18 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/62/29/2f8418269e46454a26171bfdd6a055d74febf32234e474930f2f60a17145/aiohttp-3.13.5-cp314-cp314t-win_amd64.whl", hash = "sha256:18a2f6c1182c51baa1d28d68fea51513cb2a76612f038853c0ad3c145423d3d9", size = 505441, upload-time = "2026-03-31T22:00:12.791Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiomysql"
|
||||
version = "0.3.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pymysql" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/29/e0/302aeffe8d90853556f47f3106b89c16cc2ec2a4d269bdfd82e3f4ae12cc/aiomysql-0.3.2.tar.gz", hash = "sha256:72d15ef5cfc34c03468eb41e1b90adb9fd9347b0b589114bd23ead569a02ac1a", size = 108311, upload-time = "2025-10-22T00:15:21.278Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/af/aae0153c3e28712adaf462328f6c7a3c196a1c1c27b491de4377dd3e6b52/aiomysql-0.3.2-py3-none-any.whl", hash = "sha256:c82c5ba04137d7afd5c693a258bea8ead2aad77101668044143a991e04632eb2", size = 71834, upload-time = "2025-10-22T00:15:15.905Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "aiosignal"
|
||||
version = "1.4.0"
|
||||
@@ -746,6 +759,7 @@ source = { virtual = "." }
|
||||
dependencies = [
|
||||
{ name = "bcrypt" },
|
||||
{ name = "deerflow-harness" },
|
||||
{ name = "deerflow-storage" },
|
||||
{ name = "dingtalk-stream" },
|
||||
{ name = "email-validator" },
|
||||
{ name = "fastapi" },
|
||||
@@ -763,11 +777,12 @@ dependencies = [
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
discord = [
|
||||
{ name = "discord-py" },
|
||||
mysql = [
|
||||
{ name = "deerflow-storage", extra = ["mysql"] },
|
||||
]
|
||||
postgres = [
|
||||
{ name = "deerflow-harness", extra = ["postgres"] },
|
||||
{ name = "deerflow-storage", extra = ["postgres"] },
|
||||
]
|
||||
|
||||
[package.dev-dependencies]
|
||||
@@ -783,8 +798,10 @@ requires-dist = [
|
||||
{ name = "bcrypt", specifier = ">=4.0.0" },
|
||||
{ name = "deerflow-harness", editable = "packages/harness" },
|
||||
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
|
||||
{ name = "deerflow-storage", editable = "packages/storage" },
|
||||
{ name = "deerflow-storage", extras = ["mysql"], marker = "extra == 'mysql'", editable = "packages/storage" },
|
||||
{ name = "deerflow-storage", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/storage" },
|
||||
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
|
||||
{ name = "discord-py", marker = "extra == 'discord'", specifier = ">=2.7.0" },
|
||||
{ name = "email-validator", specifier = ">=2.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.115.0" },
|
||||
{ name = "httpx", specifier = ">=0.28.0" },
|
||||
@@ -799,7 +816,7 @@ requires-dist = [
|
||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
|
||||
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
|
||||
]
|
||||
provides-extras = ["postgres", "discord"]
|
||||
provides-extras = ["postgres", "mysql"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
@@ -905,6 +922,54 @@ requires-dist = [
|
||||
]
|
||||
provides-extras = ["ollama", "postgres", "pymupdf"]
|
||||
|
||||
[[package]]
|
||||
name = "deerflow-storage"
|
||||
version = "0.1.0"
|
||||
source = { editable = "packages/storage" }
|
||||
dependencies = [
|
||||
{ name = "alembic" },
|
||||
{ name = "dotenv" },
|
||||
{ name = "langgraph" },
|
||||
{ name = "pydantic" },
|
||||
{ name = "pyyaml" },
|
||||
{ name = "sqlalchemy", extra = ["asyncio"] },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
mysql = [
|
||||
{ name = "aiomysql" },
|
||||
{ name = "langgraph-checkpoint-mysql" },
|
||||
]
|
||||
postgres = [
|
||||
{ name = "asyncpg" },
|
||||
{ name = "langgraph-checkpoint-postgres" },
|
||||
{ name = "psycopg", extra = ["binary"] },
|
||||
{ name = "psycopg-pool" },
|
||||
]
|
||||
sqlite = [
|
||||
{ name = "aiosqlite" },
|
||||
{ name = "langgraph-checkpoint-sqlite" },
|
||||
]
|
||||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "aiomysql", marker = "extra == 'mysql'", specifier = ">=0.2" },
|
||||
{ name = "aiosqlite", marker = "extra == 'sqlite'", specifier = ">=0.22.1" },
|
||||
{ name = "alembic", specifier = ">=1.13" },
|
||||
{ name = "asyncpg", marker = "extra == 'postgres'", specifier = ">=0.29" },
|
||||
{ name = "dotenv", specifier = ">=0.9.9" },
|
||||
{ name = "langgraph", specifier = ">=1.1.9" },
|
||||
{ name = "langgraph-checkpoint-mysql", marker = "extra == 'mysql'", specifier = ">=3.0.0" },
|
||||
{ name = "langgraph-checkpoint-postgres", marker = "extra == 'postgres'", specifier = ">=3.0.5" },
|
||||
{ name = "langgraph-checkpoint-sqlite", marker = "extra == 'sqlite'", specifier = ">=3.0.3" },
|
||||
{ name = "psycopg", extras = ["binary"], marker = "extra == 'postgres'", specifier = ">=3.3.3" },
|
||||
{ name = "psycopg-pool", marker = "extra == 'postgres'", specifier = ">=3.3.0" },
|
||||
{ name = "pydantic", specifier = ">=2.12.5" },
|
||||
{ name = "pyyaml", specifier = ">=6.0.3" },
|
||||
{ name = "sqlalchemy", extras = ["asyncio"], specifier = ">=2.0,<3.0" },
|
||||
]
|
||||
provides-extras = ["postgres", "mysql", "sqlite"]
|
||||
|
||||
[[package]]
|
||||
name = "defusedxml"
|
||||
version = "0.7.1"
|
||||
@@ -927,19 +992,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/44/102dede3f371277598df6aa9725b82e3add068c729333c7a5dbc12764579/dingtalk_stream-0.24.3-py3-none-any.whl", hash = "sha256:2160403656985962878bf60cdf5adf41619f21067348e06f07a7c7eebf5943ad", size = 27813, upload-time = "2025-10-24T09:36:57.497Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "discord-py"
|
||||
version = "2.7.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aiohttp" },
|
||||
{ name = "audioop-lts", marker = "python_full_version >= '3.13'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ef/57/9a2d9abdabdc9db8ef28ce0cf4129669e1c8717ba28d607b5ba357c4de3b/discord_py-2.7.1.tar.gz", hash = "sha256:24d5e6a45535152e4b98148a9dd6b550d25dc2c9fb41b6d670319411641249da", size = 1106326, upload-time = "2026-03-03T18:40:46.24Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/a7/17208c3b3f92319e7fad259f1c6d5a5baf8fd0654c54846ced329f83c3eb/discord_py-2.7.1-py3-none-any.whl", hash = "sha256:849dca2c63b171146f3a7f3f8acc04248098e9e6203412ce3cf2745f284f7439", size = 1227550, upload-time = "2026-03-03T18:40:44.492Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "distro"
|
||||
version = "1.9.0"
|
||||
@@ -1931,6 +1983,20 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/b9/5a/6dba29dd89b0a46ae21c707da0f9d17e94f27d3e481ed15bc99d6bd20aa6/langgraph_checkpoint-4.0.2-py3-none-any.whl", hash = "sha256:59b0f29216128a629c58dd07c98aa004f82f51805d5573126ffb419b753ff253", size = 51000, upload-time = "2026-04-15T21:02:59.096Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langgraph-checkpoint-mysql"
|
||||
version = "3.0.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "langgraph-checkpoint" },
|
||||
{ name = "orjson" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e0/4e/0a6c78e5d3f2ca1525903c2363e721873594b6b77dd83537a6369193c474/langgraph_checkpoint_mysql-3.0.0.tar.gz", hash = "sha256:006aaa089f4c2fbd7b2c113b800ccd3dbb95f92203e656451677256b4b4f880f", size = 213142, upload-time = "2026-01-23T11:11:15.74Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/08/68/343103a7fae05523f9cecabbec2babdb737e66b4bf6ea48ae00c685ed11c/langgraph_checkpoint_mysql-3.0.0-py3-none-any.whl", hash = "sha256:7560ccd16e7596a047e15a307cec12dbd88fdcaab45a75759e5c6adef22a27d1", size = 38009, upload-time = "2026-01-23T11:11:14.697Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "langgraph-checkpoint-postgres"
|
||||
version = "3.0.5"
|
||||
@@ -3459,6 +3525,15 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e6/38/84bf29f4dd72e6c450546df6ca8f53021f764fd945ba67dcc235d39bc20e/pymupdf4llm-1.27.2.3-py3-none-any.whl", hash = "sha256:bd724b79fa3f06a5b28d7a65f7acfa8de56e04bdb603ac2d6dff315e0d151aaa", size = 77348, upload-time = "2026-04-24T14:11:04.305Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pymysql"
|
||||
version = "1.1.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7f/ec/8d45c920e90445f0b75c590b32851853ed319763b0d8dff8d283052da8cf/pymysql-1.1.3.tar.gz", hash = "sha256:e70ebf2047a4edf6138cf79c68ad418ef620af65900aa585c5e8bfc95044d43a", size = 48207, upload-time = "2026-05-01T09:09:54.532Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/dc/9085f3d6f497e9b25fb40d6e8ecef3ddbb5cf977a949b933624a299f5c16/pymysql-1.1.3-py3-none-any.whl", hash = "sha256:8164ba62c552f6105f3b11753352d0f16b90d1703ba67d81923d5a8a5d1c5289", size = 45356, upload-time = "2026-05-01T09:09:53.316Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pypdfium2"
|
||||
version = "5.7.1"
|
||||
|
||||
@@ -1029,14 +1029,6 @@ run_events:
|
||||
# client_secret: $DINGTALK_CLIENT_SECRET
|
||||
# allowed_users: [] # empty = allow all
|
||||
# card_template_id: "" # Optional: AI Card template ID for streaming updates
|
||||
#
|
||||
# discord:
|
||||
# enabled: false
|
||||
# bot_token: $DISCORD_BOT_TOKEN
|
||||
# allowed_guilds: [] # empty = allow all guilds; can also be a single guild ID
|
||||
# mention_only: false # If true, only respond when the bot is mentioned
|
||||
# allowed_channels: [] # Optional: channel IDs exempt from mention_only (bot responds without mention)
|
||||
# thread_mode: false # If true, group a channel conversation into a thread
|
||||
|
||||
# ============================================================================
|
||||
# Guardrails Configuration
|
||||
|
||||
+3
-21
@@ -28,10 +28,6 @@ http {
|
||||
set $gateway_upstream gateway:8001;
|
||||
set $frontend_upstream frontend:3000;
|
||||
|
||||
# Default proxy settings for all locations (streaming/SSE support)
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
|
||||
# Keep the unified nginx endpoint same-origin by default. When split
|
||||
# frontend/backend or port-forwarded deployments need browser CORS,
|
||||
# configure the Gateway allowlist with GATEWAY_CORS_ORIGINS so CORS and
|
||||
@@ -53,6 +49,8 @@ http {
|
||||
proxy_set_header Connection '';
|
||||
|
||||
# SSE/Streaming support
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
proxy_set_header X-Accel-Buffering no;
|
||||
|
||||
# Timeouts for long-running requests
|
||||
@@ -72,7 +70,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Custom API: Memory endpoint
|
||||
@@ -83,7 +80,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Custom API: MCP configuration endpoint
|
||||
@@ -94,7 +90,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Custom API: Skills configuration endpoint
|
||||
@@ -105,7 +100,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Custom API: Agents endpoint
|
||||
@@ -116,7 +110,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Custom API: Uploads endpoint
|
||||
@@ -131,8 +124,6 @@ http {
|
||||
# Large file upload support
|
||||
client_max_body_size 100M;
|
||||
proxy_request_buffering off;
|
||||
|
||||
# Disable response buffering to avoid permission errors
|
||||
}
|
||||
|
||||
# Custom API: Other endpoints under /api/threads
|
||||
@@ -143,7 +134,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# API Documentation: Swagger UI
|
||||
@@ -154,7 +144,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# API Documentation: ReDoc
|
||||
@@ -165,7 +154,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# API Documentation: OpenAPI Schema
|
||||
@@ -176,7 +164,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Health check endpoint (gateway)
|
||||
@@ -187,7 +174,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# ── Provisioner API (sandbox management) ────────────────────────
|
||||
@@ -201,7 +187,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Catch-all for /api/ routes not covered above (e.g. /api/v1/auth/*).
|
||||
@@ -213,9 +198,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
# Disable buffering to avoid permission errors when nginx
|
||||
# runs as a non-root user (e.g. local development).
|
||||
}
|
||||
|
||||
# All other requests go to frontend
|
||||
@@ -238,4 +220,4 @@ http {
|
||||
proxy_read_timeout 600s;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -70,11 +70,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
# Disable buffering to avoid permission errors when nginx
|
||||
# runs as a non-root user (e.g. local development).
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: Memory endpoint
|
||||
@@ -85,9 +80,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: MCP configuration endpoint
|
||||
@@ -98,9 +90,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: Skills configuration endpoint
|
||||
@@ -111,9 +100,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: Agents endpoint
|
||||
@@ -124,9 +110,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: Uploads endpoint
|
||||
@@ -141,10 +124,6 @@ http {
|
||||
# Large file upload support
|
||||
client_max_body_size 100M;
|
||||
proxy_request_buffering off;
|
||||
|
||||
# Disable response buffering to avoid permission errors
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: Other endpoints under /api/threads
|
||||
@@ -155,9 +134,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# API Documentation: Swagger UI
|
||||
@@ -168,9 +144,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# API Documentation: ReDoc
|
||||
@@ -181,9 +154,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# API Documentation: OpenAPI Schema
|
||||
@@ -194,9 +164,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Health check endpoint (gateway)
|
||||
@@ -207,9 +174,6 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Catch-all for any /api/* prefix not matched by a more specific block above.
|
||||
@@ -229,11 +193,6 @@ http {
|
||||
# Auth endpoints set HttpOnly cookies — make sure nginx doesn't
|
||||
# strip the Set-Cookie header from upstream responses.
|
||||
proxy_pass_header Set-Cookie;
|
||||
|
||||
# Disable buffering to avoid permission errors when nginx
|
||||
# runs as a non-root user (e.g. local development).
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# All other requests go to frontend
|
||||
|
||||
@@ -66,7 +66,6 @@ export default function AgentChatPage() {
|
||||
thread,
|
||||
pendingUsageMessages,
|
||||
sendMessage,
|
||||
isUploading,
|
||||
isHistoryLoading,
|
||||
hasMoreHistory,
|
||||
loadMoreHistory,
|
||||
@@ -107,11 +106,7 @@ export default function AgentChatPage() {
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
(message: PromptInputMessage) => {
|
||||
const sendPromise = sendMessage(threadId, message, { agent_name });
|
||||
if (message.files.length > 0) {
|
||||
return sendPromise;
|
||||
}
|
||||
void sendPromise;
|
||||
void sendMessage(threadId, message, { agent_name });
|
||||
},
|
||||
[sendMessage, threadId, agent_name],
|
||||
);
|
||||
@@ -248,10 +243,7 @@ export default function AgentChatPage() {
|
||||
<AgentWelcome agent={agent} agentName={agent_name} />
|
||||
)
|
||||
}
|
||||
disabled={
|
||||
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" ||
|
||||
isUploading
|
||||
}
|
||||
disabled={env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true"}
|
||||
onContextChange={(context) => setSettings("context", context)}
|
||||
onSubmit={handleSubmit}
|
||||
onStop={handleStop}
|
||||
|
||||
@@ -109,11 +109,7 @@ export default function ChatPage() {
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
(message: PromptInputMessage) => {
|
||||
const sendPromise = sendMessage(threadId, message);
|
||||
if (message.files.length > 0) {
|
||||
return sendPromise;
|
||||
}
|
||||
void sendPromise;
|
||||
void sendMessage(threadId, message);
|
||||
},
|
||||
[sendMessage, threadId],
|
||||
);
|
||||
|
||||
@@ -499,10 +499,6 @@ export const PromptInput = ({
|
||||
// Keep a ref to files for cleanup on unmount (avoids stale closure)
|
||||
const filesRef = useRef(files);
|
||||
filesRef.current = files;
|
||||
const providerTextRef = useRef("");
|
||||
if (usingProvider) {
|
||||
providerTextRef.current = controller.textInput.value;
|
||||
}
|
||||
|
||||
const openFileDialogLocal = useCallback(() => {
|
||||
inputRef.current?.click();
|
||||
@@ -772,24 +768,6 @@ export const PromptInput = ({
|
||||
}
|
||||
|
||||
// Convert blob URLs to data URLs asynchronously
|
||||
const submittedFileIds = files.map((file) => file.id);
|
||||
const clearSubmittedState = () => {
|
||||
const currentFileIds = new Set(filesRef.current.map((file) => file.id));
|
||||
const submittedFileIdsStillPresent = submittedFileIds.filter((id) =>
|
||||
currentFileIds.has(id),
|
||||
);
|
||||
if (submittedFileIdsStillPresent.length === filesRef.current.length) {
|
||||
clear();
|
||||
} else {
|
||||
for (const id of submittedFileIdsStillPresent) {
|
||||
remove(id);
|
||||
}
|
||||
}
|
||||
if (usingProvider && providerTextRef.current === text) {
|
||||
controller.textInput.clear();
|
||||
}
|
||||
};
|
||||
|
||||
Promise.all(
|
||||
files.map(async ({ id, ...item }) => {
|
||||
if (item.file instanceof File) {
|
||||
@@ -815,14 +793,20 @@ export const PromptInput = ({
|
||||
if (result instanceof Promise) {
|
||||
result
|
||||
.then(() => {
|
||||
clearSubmittedState();
|
||||
clear();
|
||||
if (usingProvider) {
|
||||
controller.textInput.clear();
|
||||
}
|
||||
})
|
||||
.catch(() => {
|
||||
// Don't clear on error - user may want to retry
|
||||
});
|
||||
} else {
|
||||
// Sync function completed without throwing, clear attachments
|
||||
clearSubmittedState();
|
||||
clear();
|
||||
if (usingProvider) {
|
||||
controller.textInput.clear();
|
||||
}
|
||||
}
|
||||
} catch {
|
||||
// Don't clear on error - user may want to retry
|
||||
|
||||
@@ -110,7 +110,6 @@ export function InputBox({
|
||||
threadId,
|
||||
initialValue,
|
||||
onContextChange,
|
||||
onFollowupsVisibilityChange,
|
||||
onSubmit,
|
||||
onStop,
|
||||
...props
|
||||
@@ -143,8 +142,7 @@ export function InputBox({
|
||||
reasoning_effort?: "minimal" | "low" | "medium" | "high";
|
||||
},
|
||||
) => void;
|
||||
onFollowupsVisibilityChange?: (visible: boolean) => void;
|
||||
onSubmit?: (message: PromptInputMessage) => void | Promise<void>;
|
||||
onSubmit?: (message: PromptInputMessage) => void;
|
||||
onStop?: () => void;
|
||||
}) {
|
||||
const { t } = useI18n();
|
||||
@@ -253,12 +251,12 @@ export function InputBox({
|
||||
);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
(message: PromptInputMessage) => {
|
||||
async (message: PromptInputMessage) => {
|
||||
if (status === "streaming") {
|
||||
onStop?.();
|
||||
return;
|
||||
}
|
||||
if (!message.text.trim() && message.files.length === 0) {
|
||||
if (!message.text) {
|
||||
return;
|
||||
}
|
||||
setFollowups([]);
|
||||
@@ -276,14 +274,11 @@ export function InputBox({
|
||||
selectedModel?.supports_thinking ?? false,
|
||||
),
|
||||
});
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
setTimeout(() => {
|
||||
Promise.resolve(onSubmit?.(message)).then(resolve).catch(reject);
|
||||
}, 0);
|
||||
});
|
||||
setTimeout(() => onSubmit?.(message), 0);
|
||||
return;
|
||||
}
|
||||
|
||||
return onSubmit?.(message);
|
||||
onSubmit?.(message);
|
||||
},
|
||||
[
|
||||
context,
|
||||
@@ -353,14 +348,6 @@ export function InputBox({
|
||||
!followupsHidden &&
|
||||
(followupsLoading || followups.length > 0);
|
||||
|
||||
useEffect(() => {
|
||||
onFollowupsVisibilityChange?.(showFollowups);
|
||||
}, [onFollowupsVisibilityChange, showFollowups]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => onFollowupsVisibilityChange?.(false);
|
||||
}, [onFollowupsVisibilityChange]);
|
||||
|
||||
useEffect(() => {
|
||||
messagesRef.current = thread.messages;
|
||||
}, [thread.messages]);
|
||||
|
||||
@@ -26,13 +26,6 @@ export type MessageGroup =
|
||||
| AssistantClarificationGroup
|
||||
| AssistantSubagentGroup;
|
||||
|
||||
const HIDDEN_CONTROL_MESSAGE_NAMES = new Set([
|
||||
"summary",
|
||||
"loop_warning",
|
||||
"todo_reminder",
|
||||
"todo_completion_reminder",
|
||||
]);
|
||||
|
||||
export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
||||
if (messages.length === 0) {
|
||||
return [];
|
||||
@@ -60,6 +53,10 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (message.name === "todo_reminder") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (message.type === "human") {
|
||||
groups.push({ id: message.id, type: "human", messages: [message] });
|
||||
continue;
|
||||
@@ -371,8 +368,8 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) {
|
||||
export function isHiddenFromUIMessage(message: Message) {
|
||||
return (
|
||||
message.additional_kwargs?.hide_from_ui === true ||
|
||||
(typeof message.name === "string" &&
|
||||
HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name))
|
||||
message.name === "summary" ||
|
||||
message.name === "loop_warning"
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -45,60 +45,15 @@ type SendMessageOptions = {
|
||||
additionalKwargs?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
function isNonEmptyString(value: string | undefined): value is string {
|
||||
return typeof value === "string" && value.length > 0;
|
||||
}
|
||||
|
||||
function messageIdentity(message: Message): string | undefined {
|
||||
if (
|
||||
"tool_call_id" in message &&
|
||||
typeof message.tool_call_id === "string" &&
|
||||
message.tool_call_id.length > 0
|
||||
) {
|
||||
return `tool:${message.tool_call_id}`;
|
||||
}
|
||||
if (typeof message.id === "string" && message.id.length > 0) {
|
||||
return `message:${message.id}`;
|
||||
}
|
||||
return undefined;
|
||||
}
|
||||
|
||||
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
|
||||
const lastIndexByIdentity = new Map<string, number>();
|
||||
|
||||
messages.forEach((message, index) => {
|
||||
const identity = messageIdentity(message);
|
||||
if (identity) {
|
||||
lastIndexByIdentity.set(identity, index);
|
||||
}
|
||||
});
|
||||
|
||||
return messages.filter((message, index) => {
|
||||
const identity = messageIdentity(message);
|
||||
return !identity || lastIndexByIdentity.get(identity) === index;
|
||||
});
|
||||
}
|
||||
|
||||
function findLatestUnloadedRunIndex(
|
||||
runs: Run[],
|
||||
loadedRunIds: ReadonlySet<string>,
|
||||
): number {
|
||||
for (let i = runs.length - 1; i >= 0; i--) {
|
||||
const run = runs[i];
|
||||
if (run && !loadedRunIds.has(run.run_id)) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
export function mergeMessages(
|
||||
function mergeMessages(
|
||||
historyMessages: Message[],
|
||||
threadMessages: Message[],
|
||||
optimisticMessages: Message[],
|
||||
): Message[] {
|
||||
const threadMessageIds = new Set(
|
||||
threadMessages.map(messageIdentity).filter(isNonEmptyString),
|
||||
threadMessages
|
||||
.map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id))
|
||||
.filter(Boolean),
|
||||
);
|
||||
|
||||
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
|
||||
@@ -110,19 +65,28 @@ export function mergeMessages(
|
||||
if (!msg) {
|
||||
continue;
|
||||
}
|
||||
const identity = messageIdentity(msg);
|
||||
if (identity && threadMessageIds.has(identity)) {
|
||||
if (
|
||||
(msg?.id && threadMessageIds.has(msg.id)) ||
|
||||
("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id))
|
||||
) {
|
||||
cutoff = i;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return dedupeMessagesByIdentity([
|
||||
return [
|
||||
...historyMessages.slice(0, cutoff),
|
||||
...threadMessages,
|
||||
...optimisticMessages,
|
||||
]);
|
||||
];
|
||||
}
|
||||
|
||||
function messageIdentity(message: Message): string | undefined {
|
||||
if ("tool_call_id" in message) {
|
||||
return message.tool_call_id;
|
||||
}
|
||||
return message.id;
|
||||
}
|
||||
|
||||
function getMessagesAfterBaseline(
|
||||
@@ -663,105 +627,48 @@ export function useThreadHistory(threadId: string) {
|
||||
const runsRef = useRef(runs.data ?? []);
|
||||
const indexRef = useRef(-1);
|
||||
const loadingRef = useRef(false);
|
||||
const pendingLoadRef = useRef(false);
|
||||
const loadingRunIdRef = useRef<string | null>(null);
|
||||
const loadedRunIdsRef = useRef<Set<string>>(new Set());
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
|
||||
loadingRef.current = loading;
|
||||
const loadMessages = useCallback(async () => {
|
||||
if (loadingRef.current) {
|
||||
const pendingRunIndex = findLatestUnloadedRunIndex(
|
||||
runsRef.current,
|
||||
loadedRunIdsRef.current,
|
||||
);
|
||||
const pendingRun = runsRef.current[pendingRunIndex];
|
||||
if (pendingRun && pendingRun.run_id !== loadingRunIdRef.current) {
|
||||
pendingLoadRef.current = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (runsRef.current.length === 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
loadingRef.current = true;
|
||||
setLoading(true);
|
||||
|
||||
const run = runsRef.current[indexRef.current];
|
||||
if (!run || loadingRef.current) {
|
||||
return;
|
||||
}
|
||||
try {
|
||||
do {
|
||||
pendingLoadRef.current = false;
|
||||
|
||||
const nextRunIndex = findLatestUnloadedRunIndex(
|
||||
runsRef.current,
|
||||
loadedRunIdsRef.current,
|
||||
);
|
||||
indexRef.current = nextRunIndex;
|
||||
|
||||
const run = runsRef.current[nextRunIndex];
|
||||
if (!run) {
|
||||
indexRef.current = -1;
|
||||
return;
|
||||
}
|
||||
|
||||
const requestThreadId = threadIdRef.current;
|
||||
loadingRunIdRef.current = run.run_id;
|
||||
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
|
||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(requestThreadId)}/runs/${encodeURIComponent(run.run_id)}/messages`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
credentials: "include",
|
||||
setLoading(true);
|
||||
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
|
||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
).then((res) => {
|
||||
return res.json();
|
||||
});
|
||||
const _messages = result.data
|
||||
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
||||
.map((m) => m.content);
|
||||
if (threadIdRef.current !== requestThreadId) {
|
||||
return;
|
||||
}
|
||||
setMessages((prev) =>
|
||||
dedupeMessagesByIdentity([..._messages, ...prev]),
|
||||
);
|
||||
loadedRunIdsRef.current.add(run.run_id);
|
||||
indexRef.current = findLatestUnloadedRunIndex(
|
||||
runsRef.current,
|
||||
loadedRunIdsRef.current,
|
||||
);
|
||||
} while (pendingLoadRef.current);
|
||||
credentials: "include",
|
||||
},
|
||||
).then((res) => {
|
||||
return res.json();
|
||||
});
|
||||
const _messages = result.data
|
||||
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
||||
.map((m) => m.content);
|
||||
setMessages((prev) => [..._messages, ...prev]);
|
||||
indexRef.current -= 1;
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
} finally {
|
||||
loadingRef.current = false;
|
||||
loadingRunIdRef.current = null;
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
useEffect(() => {
|
||||
const threadChanged = threadIdRef.current !== threadId;
|
||||
threadIdRef.current = threadId;
|
||||
|
||||
if (threadChanged) {
|
||||
runsRef.current = [];
|
||||
indexRef.current = -1;
|
||||
pendingLoadRef.current = false;
|
||||
loadingRunIdRef.current = null;
|
||||
loadedRunIdsRef.current = new Set();
|
||||
loadingRef.current = false;
|
||||
setLoading(false);
|
||||
setMessages([]);
|
||||
}
|
||||
|
||||
if (runs.data && runs.data.length > 0) {
|
||||
runsRef.current = runs.data ?? [];
|
||||
indexRef.current = findLatestUnloadedRunIndex(
|
||||
runs.data,
|
||||
loadedRunIdsRef.current,
|
||||
);
|
||||
indexRef.current = runs.data.length - 1;
|
||||
}
|
||||
loadMessages().catch(() => {
|
||||
toast.error("Failed to load thread history.");
|
||||
@@ -770,7 +677,7 @@ export function useThreadHistory(threadId: string) {
|
||||
|
||||
const appendMessages = useCallback((_messages: Message[]) => {
|
||||
setMessages((prev) => {
|
||||
return dedupeMessagesByIdentity([...prev, ..._messages]);
|
||||
return [...prev, ..._messages];
|
||||
});
|
||||
}, []);
|
||||
const hasMore = indexRef.current >= 0 || !runs.data;
|
||||
|
||||
@@ -48,66 +48,4 @@ test.describe("Chat workspace", () => {
|
||||
timeout: 10_000,
|
||||
});
|
||||
});
|
||||
|
||||
test("keeps attachments visible while upload submit is pending", async ({
|
||||
page,
|
||||
}) => {
|
||||
let releaseUpload!: () => void;
|
||||
const uploadCanFinish = new Promise<void>((resolve) => {
|
||||
releaseUpload = resolve;
|
||||
});
|
||||
let uploadStarted!: () => void;
|
||||
const uploadStartedPromise = new Promise<void>((resolve) => {
|
||||
uploadStarted = resolve;
|
||||
});
|
||||
|
||||
await page.route("**/api/threads/*/uploads", async (route) => {
|
||||
uploadStarted();
|
||||
await uploadCanFinish;
|
||||
return route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({
|
||||
success: true,
|
||||
message: "Uploaded",
|
||||
files: [
|
||||
{
|
||||
filename: "report.docx",
|
||||
size: 12,
|
||||
path: "report.docx",
|
||||
virtual_path: "/mnt/user-data/uploads/report.docx",
|
||||
artifact_url: "/api/threads/test/uploads/report.docx",
|
||||
extension: ".docx",
|
||||
},
|
||||
],
|
||||
}),
|
||||
});
|
||||
});
|
||||
|
||||
await page.goto("/workspace/chats/new");
|
||||
|
||||
const textarea = page.getByPlaceholder(/how can i assist you/i);
|
||||
await expect(textarea).toBeVisible({ timeout: 15_000 });
|
||||
const promptForm = page.locator("form").filter({ has: textarea });
|
||||
|
||||
await page.getByLabel("Upload files").setInputFiles({
|
||||
name: "report.docx",
|
||||
mimeType:
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
buffer: Buffer.from("fake docx"),
|
||||
});
|
||||
await expect(promptForm.getByText("report.docx")).toBeVisible();
|
||||
|
||||
await textarea.fill("Summarize this document");
|
||||
await textarea.press("Enter");
|
||||
|
||||
await uploadStartedPromise;
|
||||
await expect(promptForm.getByText("report.docx")).toBeVisible();
|
||||
|
||||
releaseUpload();
|
||||
await expect(page.getByText("Hello from DeerFlow!")).toBeVisible({
|
||||
timeout: 10_000,
|
||||
});
|
||||
await expect(promptForm.getByText("report.docx")).toBeHidden();
|
||||
});
|
||||
});
|
||||
|
||||
@@ -63,37 +63,3 @@ test("aggregates token usage messages once per assistant turn", () => {
|
||||
),
|
||||
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
|
||||
});
|
||||
|
||||
test("hides internal todo reminder messages from message groups", () => {
|
||||
const messages = [
|
||||
{
|
||||
id: "human-1",
|
||||
type: "human",
|
||||
content: "Audit the middleware",
|
||||
},
|
||||
{
|
||||
id: "todo-reminder-1",
|
||||
type: "human",
|
||||
name: "todo_completion_reminder",
|
||||
content: "<system_reminder>finish todos</system_reminder>",
|
||||
},
|
||||
{
|
||||
id: "todo-reminder-2",
|
||||
type: "human",
|
||||
name: "todo_reminder",
|
||||
content: "<system_reminder>remember todos</system_reminder>",
|
||||
},
|
||||
{
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "Done",
|
||||
},
|
||||
] as Message[];
|
||||
|
||||
const groups = getMessageGroups(messages);
|
||||
|
||||
expect(groups.map((group) => group.type)).toEqual(["human", "assistant"]);
|
||||
expect(
|
||||
groups.flatMap((group) => group.messages).map((message) => message.id),
|
||||
).toEqual(["human-1", "ai-1"]);
|
||||
});
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
import { expect, test } from "vitest";
|
||||
|
||||
import { mergeMessages } from "@/core/threads/hooks";
|
||||
|
||||
test("mergeMessages removes duplicate messages already present in history", () => {
|
||||
const human = {
|
||||
id: "human-1",
|
||||
type: "human",
|
||||
content: "Design an agent",
|
||||
} as Message;
|
||||
const ai = {
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "Let's design it.",
|
||||
} as Message;
|
||||
|
||||
expect(mergeMessages([human, ai, human, ai], [], [])).toEqual([human, ai]);
|
||||
});
|
||||
|
||||
test("mergeMessages lets live thread messages replace overlapping history", () => {
|
||||
const oldHuman = {
|
||||
id: "human-1",
|
||||
type: "human",
|
||||
content: "old",
|
||||
} as Message;
|
||||
const liveHuman = {
|
||||
id: "human-1",
|
||||
type: "human",
|
||||
content: "live",
|
||||
} as Message;
|
||||
const oldAi = {
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "old",
|
||||
} as Message;
|
||||
const liveAi = {
|
||||
id: "ai-1",
|
||||
type: "ai",
|
||||
content: "live",
|
||||
} as Message;
|
||||
|
||||
expect(mergeMessages([oldHuman, oldAi], [liveHuman, liveAi], [])).toEqual([
|
||||
liveHuman,
|
||||
liveAi,
|
||||
]);
|
||||
});
|
||||
|
||||
test("mergeMessages deduplicates tool messages by tool_call_id", () => {
|
||||
const oldTool = {
|
||||
id: "tool-message-old",
|
||||
type: "tool",
|
||||
tool_call_id: "call-1",
|
||||
content: "old",
|
||||
} as Message;
|
||||
const liveTool = {
|
||||
id: "tool-message-live",
|
||||
type: "tool",
|
||||
tool_call_id: "call-1",
|
||||
content: "live",
|
||||
} as Message;
|
||||
|
||||
expect(mergeMessages([oldTool], [liveTool], [])).toEqual([liveTool]);
|
||||
});
|
||||
@@ -72,7 +72,6 @@ def find_config_file() -> Path | None:
|
||||
|
||||
|
||||
_SECTION_RE = re.compile(r"^([A-Za-z_][\w-]*)\s*:\s*$")
|
||||
_INDENTED_SECTION_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*$")
|
||||
_KEY_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*(\S.*?)\s*$")
|
||||
|
||||
|
||||
@@ -142,84 +141,6 @@ def section_value(lines: list[str], section: str, key: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def nested_section_value(lines: list[str], section_path: str, key: str) -> str | None:
|
||||
"""Return the value of a nested YAML key like ``channels.discord.enabled``.
|
||||
|
||||
Handles two levels of nesting:
|
||||
channels:
|
||||
discord:
|
||||
enabled: true
|
||||
"""
|
||||
parts = section_path.split(".")
|
||||
if len(parts) != 2:
|
||||
return None
|
||||
parent_section, child_section = parts
|
||||
|
||||
inside_parent = False
|
||||
inside_child = False
|
||||
parent_indent: int | None = None
|
||||
child_indent: int | None = None
|
||||
|
||||
for raw in lines:
|
||||
line = _strip_comment(raw)
|
||||
if not line.strip():
|
||||
continue
|
||||
|
||||
stripped = line.lstrip()
|
||||
indent = len(line) - len(stripped)
|
||||
|
||||
# Top-level section match
|
||||
sect_match = _SECTION_RE.match(line)
|
||||
if sect_match:
|
||||
if indent == 0:
|
||||
inside_parent = sect_match.group(1) == parent_section
|
||||
inside_child = False
|
||||
parent_indent = None
|
||||
child_indent = None
|
||||
continue
|
||||
|
||||
if not inside_parent:
|
||||
continue
|
||||
|
||||
# Track parent indent from first child
|
||||
if parent_indent is None and indent > 0:
|
||||
parent_indent = indent
|
||||
|
||||
# If indent goes back to 0, we left the parent section
|
||||
if indent == 0:
|
||||
inside_parent = False
|
||||
inside_child = False
|
||||
continue
|
||||
|
||||
# Check if we're at the parent's child level (subsection)
|
||||
if parent_indent is not None and indent == parent_indent:
|
||||
# This could be a subsection or a direct key of parent
|
||||
sub_match = _INDENTED_SECTION_RE.match(line)
|
||||
if sub_match and sub_match.group(1) == child_section:
|
||||
inside_child = True
|
||||
child_indent = None
|
||||
continue
|
||||
else:
|
||||
inside_child = False
|
||||
continue
|
||||
|
||||
if not inside_child:
|
||||
continue
|
||||
|
||||
# We're inside the subsection — track child indent
|
||||
if child_indent is None and indent > (parent_indent or 0):
|
||||
child_indent = indent
|
||||
|
||||
if child_indent is not None and indent != child_indent:
|
||||
continue
|
||||
|
||||
key_match = _KEY_RE.match(line)
|
||||
if key_match and key_match.group(1) == key:
|
||||
return _unquote(key_match.group(2).strip())
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def detect_from_config(path: Path) -> list[str]:
|
||||
try:
|
||||
text = path.read_text(encoding="utf-8", errors="replace")
|
||||
@@ -231,8 +152,6 @@ def detect_from_config(path: Path) -> list[str]:
|
||||
extras.add("postgres")
|
||||
if (section_value(lines, "checkpointer", "type") or "").lower() == "postgres":
|
||||
extras.add("postgres")
|
||||
if (nested_section_value(lines, "channels.discord", "enabled") or "").lower() == "true":
|
||||
extras.add("discord")
|
||||
return sorted(extras)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user