Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7752e74e2b | |||
| ba99a23814 | |||
| 6d611c2bf6 | |||
| 6d3cffb4f0 | |||
| 48e038f752 | |||
| 7c42ab3e16 | |||
| 7a2670eaea | |||
| 0c37509b38 | |||
| 181d836541 | |||
| 2b2742c034 | |||
| 6ffe267d20 | |||
| c995c3a394 |
+291
-11
@@ -3,8 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.channels.base import Channel
|
||||
@@ -21,6 +23,12 @@ class DiscordChannel(Channel):
|
||||
Configuration keys (in ``config.yaml`` under ``channels.discord``):
|
||||
- ``bot_token``: Discord Bot token.
|
||||
- ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
|
||||
- ``mention_only``: (optional) If true, only respond when the bot is mentioned.
|
||||
- ``allowed_channels``: (optional) List of channel IDs where messages are always accepted
|
||||
(even when mention_only is true). Use for channels where you want the bot to respond
|
||||
without mentions. Empty = mention_only applies everywhere.
|
||||
- ``thread_mode``: (optional) If true, group a channel conversation into a thread.
|
||||
Default: same as ``mention_only``.
|
||||
"""
|
||||
|
||||
def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
|
||||
@@ -32,6 +40,29 @@ class DiscordChannel(Channel):
|
||||
self._allowed_guilds.add(int(guild_id))
|
||||
except (TypeError, ValueError):
|
||||
continue
|
||||
self._mention_only: bool = bool(config.get("mention_only", False))
|
||||
self._thread_mode: bool = config.get("thread_mode", self._mention_only)
|
||||
self._allowed_channels: set[str] = set()
|
||||
for channel_id in config.get("allowed_channels", []):
|
||||
self._allowed_channels.add(str(channel_id))
|
||||
|
||||
# Session tracking: channel_id -> Discord thread_id (in-memory, persisted to JSON).
|
||||
# Uses a dedicated JSON file separate from ChannelStore, which maps IM
|
||||
# conversations to DeerFlow thread IDs — a different concern.
|
||||
self._active_threads: dict[str, str] = {}
|
||||
# Reverse-lookup set for O(1) thread ID checks (avoids O(n) scan of _active_threads.values()).
|
||||
self._active_thread_ids: set[str] = set()
|
||||
# Lock protecting _active_threads and the JSON file from concurrent access.
|
||||
# _run_client (Discord loop thread) and the main thread both read/write.
|
||||
self._thread_store_lock = threading.Lock()
|
||||
store = config.get("channel_store")
|
||||
if store is not None:
|
||||
self._thread_store_path = store._path.parent / "discord_threads.json"
|
||||
else:
|
||||
self._thread_store_path = Path.home() / ".deer-flow" / "channels" / "discord_threads.json"
|
||||
|
||||
# Typing indicator management
|
||||
self._typing_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
self._client = None
|
||||
self._thread: threading.Thread | None = None
|
||||
@@ -75,12 +106,56 @@ class DiscordChannel(Channel):
|
||||
|
||||
self._thread = threading.Thread(target=self._run_client, daemon=True)
|
||||
self._thread.start()
|
||||
self._load_active_threads()
|
||||
logger.info("Discord channel started")
|
||||
|
||||
def _load_active_threads(self) -> None:
    """Restore Discord thread mappings from the dedicated JSON file on startup.

    The file at ``self._thread_store_path`` is expected to hold a JSON object
    mapping channel IDs to Discord thread IDs. The in-memory forward map
    (``_active_threads``) and reverse-lookup set (``_active_thread_ids``) are
    replaced only after the file has been parsed and validated, so a malformed
    file never wipes state that is already loaded.

    Failures are logged and otherwise ignored; nothing is raised.
    """
    with self._thread_store_lock:
        try:
            if not self._thread_store_path.exists():
                logger.debug("[Discord] no thread mappings file at %s", self._thread_store_path)
                return

            data = json.loads(self._thread_store_path.read_text(encoding="utf-8"))
            if not isinstance(data, dict):
                logger.warning(
                    "[Discord] ignoring thread mappings file %s: expected a JSON object, got %s",
                    self._thread_store_path,
                    type(data).__name__,
                )
                return

            # IDs are compared against str(...) values elsewhere in this class,
            # so normalise hand-edited numeric IDs and drop null/empty entries.
            restored = {
                str(channel_id): str(thread_id)
                for channel_id, thread_id in data.items()
                if thread_id is not None and str(thread_id) != ""
            }

            self._active_threads.clear()
            self._active_threads.update(restored)
            self._active_thread_ids.clear()
            self._active_thread_ids.update(restored.values())

            if self._active_threads:
                logger.info("[Discord] restored %d thread mappings from %s", len(self._active_threads), self._thread_store_path)
        except Exception:
            logger.exception("[Discord] failed to load thread mappings")
|
||||
|
||||
def _save_thread(self, channel_id: str, thread_id: str) -> None:
    """Persist a Discord thread mapping to the dedicated JSON file.

    Args:
        channel_id: Discord channel ID the thread belongs to.
        thread_id: Discord thread ID now active for that channel.

    The in-memory forward map and reverse-lookup set are always updated, even
    when the file on disk is unreadable. If the file cannot be parsed, the
    in-memory map is used as the base for the rewrite instead of giving up.
    The file is written to a sibling temp file and renamed into place so an
    interrupted write does not leave a truncated JSON document behind.

    Failures are logged and otherwise ignored; nothing is raised.
    """
    with self._thread_store_lock:
        path = self._thread_store_path
        self._active_threads[channel_id] = thread_id

        data: dict[str, str] | None = None
        try:
            if path.exists():
                loaded = json.loads(path.read_text(encoding="utf-8"))
                if isinstance(loaded, dict):
                    data = loaded
                else:
                    logger.warning("[Discord] thread mappings file %s is not a JSON object; rebuilding from memory", path)
            else:
                data = {}
        except (OSError, ValueError):
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            logger.exception("[Discord] failed to read thread mappings from %s; rebuilding from memory", path)

        if data is None:
            data = dict(self._active_threads)

        old_id = data.get(channel_id)
        data[channel_id] = thread_id

        # Update reverse-lookup set. Callers may already have stored the new ID,
        # so never discard the ID we are about to add.
        if old_id and old_id != thread_id:
            self._active_thread_ids.discard(old_id)
        self._active_thread_ids.add(thread_id)

        tmp_path = path.with_name(path.name + ".tmp")
        try:
            path.parent.mkdir(parents=True, exist_ok=True)
            tmp_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
            tmp_path.replace(path)
        except Exception:
            logger.exception("[Discord] failed to save thread mapping for channel %s", channel_id)
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                logger.debug("[Discord] could not remove temp file %s", tmp_path, exc_info=True)
|
||||
|
||||
async def stop(self) -> None:
|
||||
self._running = False
|
||||
self.bus.unsubscribe_outbound(self._on_outbound)
|
||||
|
||||
# Cancel all active typing indicator tasks
|
||||
for target_id, task in list(self._typing_tasks.items()):
|
||||
if not task.done():
|
||||
task.cancel()
|
||||
logger.debug("[Discord] cancelled typing task for target %s", target_id)
|
||||
self._typing_tasks.clear()
|
||||
|
||||
if self._client and self._discord_loop and self._discord_loop.is_running():
|
||||
close_future = asyncio.run_coroutine_threadsafe(self._client.close(), self._discord_loop)
|
||||
try:
|
||||
@@ -100,6 +175,10 @@ class DiscordChannel(Channel):
|
||||
logger.info("Discord channel stopped")
|
||||
|
||||
async def send(self, msg: OutboundMessage) -> None:
|
||||
# Stop typing indicator once we're sending the response
|
||||
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
|
||||
await asyncio.wrap_future(stop_future)
|
||||
|
||||
target = await self._resolve_target(msg)
|
||||
if target is None:
|
||||
logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
|
||||
@@ -111,6 +190,9 @@ class DiscordChannel(Channel):
|
||||
await asyncio.wrap_future(send_future)
|
||||
|
||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||
stop_future = asyncio.run_coroutine_threadsafe(self._stop_typing(msg.chat_id, msg.thread_ts), self._discord_loop)
|
||||
await asyncio.wrap_future(stop_future)
|
||||
|
||||
target = await self._resolve_target(msg)
|
||||
if target is None:
|
||||
logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
|
||||
@@ -130,6 +212,41 @@ class DiscordChannel(Channel):
|
||||
logger.exception("[Discord] failed to upload file: %s", attachment.filename)
|
||||
return False
|
||||
|
||||
async def _start_typing(self, channel, chat_id: str, thread_ts: str | None = None) -> None:
    """Start a background loop that periodically sends a typing indicator.

    Args:
        channel: Discord channel or thread object to type into.
        chat_id: Channel ID; used as the task key when ``thread_ts`` is empty.
        thread_ts: Thread ID; takes precedence over ``chat_id`` as the task key.

    The loop task is stored in ``self._typing_tasks`` under the target key and
    runs until it is cancelled. Only one live loop is kept per target.
    """
    target_id = thread_ts or chat_id
    existing = self._typing_tasks.get(target_id)
    if existing is not None and not existing.done():
        return  # Already typing for this target

    # ``trigger_typing()`` is the older coroutine API; newer discord.py releases
    # expose an awaitable ``typing()`` instead. Resolve once rather than failing
    # (silently) on every iteration.
    send_typing = getattr(channel, "trigger_typing", None) or getattr(channel, "typing", None)
    if send_typing is None:
        logger.debug("[Discord] target %s does not support typing indicators", target_id)
        return

    async def _typing_loop() -> None:
        try:
            while True:
                try:
                    await send_typing()
                except Exception:
                    # Best effort: a failed indicator must not break message handling.
                    logger.debug("[Discord] failed to send typing indicator for target %s", target_id, exc_info=True)
                await asyncio.sleep(10)
        except asyncio.CancelledError:
            # Cancellation is the normal way this loop ends.
            pass

    task = asyncio.create_task(_typing_loop())
    self._typing_tasks[target_id] = task
|
||||
|
||||
async def _stop_typing(self, chat_id: str, thread_ts: str | None = None) -> None:
    """Cancel and forget the typing loop registered for a target, if any.

    The target key is ``thread_ts`` when it is non-empty, otherwise ``chat_id``.
    The entry is removed from ``self._typing_tasks`` whether or not the task
    was still running.
    """
    key = chat_id if not thread_ts else thread_ts
    pending = self._typing_tasks.pop(key, None)
    if not pending or pending.done():
        return
    pending.cancel()
    logger.debug("[Discord] stopped typing indicator for target %s", key)
|
||||
|
||||
async def _add_reaction(self, message) -> None:
    """Acknowledge receipt of ``message`` with a checkmark reaction (best effort).

    Any failure is logged at debug level and never propagated.
    """
    ack_emoji = "✅"
    try:
        await message.add_reaction(ack_emoji)
    except Exception:
        logger.debug(
            "[Discord] failed to add reaction to message %s",
            message.id,
            exc_info=True,
        )
|
||||
|
||||
async def _on_message(self, message) -> None:
|
||||
if not self._running or not self._client:
|
||||
return
|
||||
@@ -152,15 +269,143 @@ class DiscordChannel(Channel):
|
||||
if self._discord_module is None:
|
||||
return
|
||||
|
||||
if isinstance(message.channel, self._discord_module.Thread):
|
||||
chat_id = str(message.channel.parent_id or message.channel.id)
|
||||
thread_id = str(message.channel.id)
|
||||
# Determine whether the bot is mentioned in this message
|
||||
user = self._client.user if self._client else None
|
||||
if user:
|
||||
bot_mention = user.mention # <@ID>
|
||||
alt_mention = f"<@!{user.id}>" # <@!ID> (ping variant)
|
||||
standard_mention = f"<@{user.id}>"
|
||||
else:
|
||||
thread = await self._create_thread(message)
|
||||
if thread is None:
|
||||
bot_mention = None
|
||||
alt_mention = None
|
||||
standard_mention = ""
|
||||
has_mention = (bot_mention and bot_mention in message.content) or (alt_mention and alt_mention in message.content) or (standard_mention and standard_mention in message.content)
|
||||
|
||||
# Strip mention from text for processing
|
||||
if has_mention:
|
||||
text = text.replace(bot_mention or "", "").replace(alt_mention or "", "").replace(standard_mention or "", "").strip()
|
||||
# Don't return early if text is empty — still process the mention (e.g., create thread)
|
||||
|
||||
# --- Determine thread/channel routing and typing target ---
|
||||
thread_id = None
|
||||
chat_id = None
|
||||
typing_target = None # The Discord object to type into
|
||||
|
||||
if isinstance(message.channel, self._discord_module.Thread):
|
||||
# --- Message already inside a thread ---
|
||||
thread_obj = message.channel
|
||||
thread_id = str(thread_obj.id)
|
||||
chat_id = str(thread_obj.parent_id or thread_obj.id)
|
||||
typing_target = thread_obj
|
||||
|
||||
# If this is a known active thread, process normally
|
||||
if thread_id in self._active_thread_ids:
|
||||
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
||||
inbound = self._make_inbound(
|
||||
chat_id=chat_id,
|
||||
user_id=str(message.author.id),
|
||||
text=text,
|
||||
msg_type=msg_type,
|
||||
thread_ts=thread_id,
|
||||
metadata={
|
||||
"guild_id": str(guild.id) if guild else None,
|
||||
"channel_id": str(message.channel.id),
|
||||
"message_id": str(message.id),
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
self._publish(inbound)
|
||||
# Start typing indicator in the thread
|
||||
if typing_target:
|
||||
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
|
||||
asyncio.create_task(self._add_reaction(message))
|
||||
return
|
||||
chat_id = str(message.channel.id)
|
||||
thread_id = str(thread.id)
|
||||
|
||||
# Thread not tracked (orphaned) — create new thread and handle below
|
||||
logger.debug("[Discord] message in orphaned thread %s, will create new thread", thread_id)
|
||||
thread_id = None
|
||||
typing_target = None
|
||||
|
||||
# At this point we're guaranteed to be in a channel, not a thread
|
||||
# (the Thread case is handled above). Apply mention_only for all
|
||||
# non-thread messages — no special case needed.
|
||||
channel_id = str(message.channel.id)
|
||||
|
||||
# Check if there's an active thread for this channel
|
||||
if channel_id in self._active_threads:
|
||||
# respect mention_only: if enabled, only process messages that mention the bot
|
||||
# (unless the channel is in allowed_channels)
|
||||
# Messages within a thread are always allowed through (continuation).
|
||||
# At this code point we know the message is in a channel, not a thread
|
||||
# (Thread case handled above), so always apply the check.
|
||||
if self._mention_only and not has_mention and channel_id not in self._allowed_channels:
|
||||
logger.debug("[Discord] skipping no-@ message in channel %s (not in thread)", channel_id)
|
||||
return
|
||||
# mention_only + fresh @ → create new thread instead of routing to existing one
|
||||
if self._mention_only and has_mention:
|
||||
thread_obj = await self._create_thread(message)
|
||||
if thread_obj is not None:
|
||||
target_thread_id = str(thread_obj.id)
|
||||
self._active_threads[channel_id] = target_thread_id
|
||||
self._save_thread(channel_id, target_thread_id)
|
||||
thread_id = target_thread_id
|
||||
chat_id = channel_id
|
||||
typing_target = thread_obj
|
||||
logger.info("[Discord] created new thread %s in channel %s on mention (replacing existing thread)", target_thread_id, channel_id)
|
||||
else:
|
||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
||||
thread_id = channel_id
|
||||
chat_id = channel_id
|
||||
typing_target = message.channel
|
||||
else:
|
||||
# Existing session → route to the existing thread
|
||||
target_thread_id = self._active_threads[channel_id]
|
||||
logger.debug("[Discord] routing message in channel %s to existing thread %s", channel_id, target_thread_id)
|
||||
thread_id = target_thread_id
|
||||
chat_id = channel_id
|
||||
typing_target = await self._get_channel_or_thread(target_thread_id)
|
||||
elif self._mention_only and not has_mention and channel_id not in self._allowed_channels:
|
||||
# Not mentioned and not in an allowed channel → skip
|
||||
logger.debug("[Discord] skipping message without mention in channel %s", channel_id)
|
||||
return
|
||||
elif self._mention_only and has_mention:
|
||||
# First mention in this channel → create thread
|
||||
thread_obj = await self._create_thread(message)
|
||||
if thread_obj is not None:
|
||||
target_thread_id = str(thread_obj.id)
|
||||
self._active_threads[channel_id] = target_thread_id
|
||||
self._save_thread(channel_id, target_thread_id)
|
||||
thread_id = target_thread_id
|
||||
chat_id = channel_id
|
||||
typing_target = thread_obj # Type into the new thread
|
||||
logger.info("[Discord] created thread %s in channel %s for user %s", target_thread_id, channel_id, message.author.display_name)
|
||||
else:
|
||||
# Fallback: thread creation failed (disabled/permissions), reply in channel
|
||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
||||
thread_id = channel_id
|
||||
chat_id = channel_id
|
||||
typing_target = message.channel # Type into the channel
|
||||
elif self._thread_mode:
|
||||
# thread_mode but mention_only is False → create thread anyway for conversation grouping
|
||||
thread_obj = await self._create_thread(message)
|
||||
if thread_obj is None:
|
||||
# Thread creation failed (disabled/permissions), fall back to channel replies
|
||||
logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
|
||||
thread_id = channel_id
|
||||
chat_id = channel_id
|
||||
typing_target = message.channel # Type into the channel
|
||||
else:
|
||||
target_thread_id = str(thread_obj.id)
|
||||
self._active_threads[channel_id] = target_thread_id
|
||||
self._save_thread(channel_id, target_thread_id)
|
||||
thread_id = target_thread_id
|
||||
chat_id = channel_id
|
||||
typing_target = thread_obj # Type into the new thread
|
||||
else:
|
||||
# No threading — reply directly in channel
|
||||
thread_id = channel_id
|
||||
chat_id = channel_id
|
||||
typing_target = message.channel # Type into the channel
|
||||
|
||||
msg_type = InboundMessageType.COMMAND if text.startswith("/") else InboundMessageType.CHAT
|
||||
inbound = self._make_inbound(
|
||||
@@ -177,6 +422,15 @@ class DiscordChannel(Channel):
|
||||
)
|
||||
inbound.topic_id = thread_id
|
||||
|
||||
# Start typing indicator in the correct target (thread or channel)
|
||||
if typing_target:
|
||||
asyncio.create_task(self._start_typing(typing_target, chat_id, thread_id))
|
||||
|
||||
self._publish(inbound)
|
||||
asyncio.create_task(self._add_reaction(message))
|
||||
|
||||
def _publish(self, inbound) -> None:
|
||||
"""Publish an inbound message to the main event loop."""
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), self._main_loop)
|
||||
future.add_done_callback(lambda f: logger.exception("[Discord] publish_inbound failed", exc_info=f.exception()) if f.exception() else None)
|
||||
@@ -198,14 +452,40 @@ class DiscordChannel(Channel):
|
||||
|
||||
async def _create_thread(self, message):
    """Create a Discord thread anchored on ``message``.

    Returns:
        The created thread object, or ``None`` when the Discord module is not
        loaded, the channel type does not support threads, or creation fails.

    Never raises: HTTP errors are logged, and any other failure is logged and
    followed by a best-effort notice posted to the originating channel.
    """
    discord_module = self._discord_module
    if discord_module is None:
        return None

    try:
        # Only TextChannel and NewsChannel support threads
        channel_type = message.channel.type
        if channel_type not in (
            discord_module.ChannelType.text,
            discord_module.ChannelType.news,
        ):
            logger.info(
                "[Discord] channel type %s (%s) does not support threads",
                channel_type.value,
                channel_type.name,
            )
            return None

        # Truncated to 100 characters.
        thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
        return await message.create_thread(name=thread_name)
    except discord_module.errors.HTTPException as exc:
        if exc.code == 50024:
            # Look the type up defensively so logging cannot raise inside the handler.
            failed_type = getattr(message.channel, "type", None)
            logger.info(
                "[Discord] cannot create thread in channel %s (error code 50024): %s",
                message.channel.id,
                failed_type.name if failed_type else "unknown",
            )
        else:
            logger.exception(
                "[Discord] failed to create thread for message=%s (HTTPException %s)",
                message.id,
                exc.code,
            )
        return None
    except Exception:
        logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
        try:
            await message.channel.send("Could not create a thread for your message. Please check that threads are enabled in this channel.")
        except Exception:
            # The notice is best effort, but leave a trace instead of swallowing silently.
            logger.debug("[Discord] could not post thread-creation failure notice for message=%s", message.id, exc_info=True)
        return None
|
||||
|
||||
async def _resolve_target(self, msg: OutboundMessage):
|
||||
|
||||
@@ -787,13 +787,22 @@ class ChannelManager:
|
||||
return
|
||||
|
||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||
result = await client.runs.wait(
|
||||
thread_id,
|
||||
assistant_id,
|
||||
input={"messages": [{"role": "human", "content": msg.text}]},
|
||||
config=run_config,
|
||||
context=run_context,
|
||||
)
|
||||
try:
|
||||
result = await client.runs.wait(
|
||||
thread_id,
|
||||
assistant_id,
|
||||
input={"messages": [{"role": "human", "content": msg.text}]},
|
||||
config=run_config,
|
||||
context=run_context,
|
||||
multitask_strategy="reject",
|
||||
)
|
||||
except Exception as exc:
|
||||
if _is_thread_busy_error(exc):
|
||||
logger.warning("[Manager] thread busy (concurrent run rejected): thread_id=%s", thread_id)
|
||||
await self._send_error(msg, THREAD_BUSY_MESSAGE)
|
||||
return
|
||||
else:
|
||||
raise
|
||||
|
||||
response_text = _extract_response_text(result)
|
||||
artifacts = _extract_artifacts(result)
|
||||
|
||||
@@ -167,6 +167,8 @@ class ChannelService:
|
||||
return False
|
||||
|
||||
try:
|
||||
config = dict(config)
|
||||
config["channel_store"] = self.store
|
||||
channel = channel_cls(bus=self.bus, config=config)
|
||||
self._channels[name] = channel
|
||||
await channel.start()
|
||||
|
||||
@@ -8,6 +8,8 @@ from pydantic import BaseModel, Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_SECRET_FILE = ".jwt_secret"
|
||||
|
||||
|
||||
class AuthConfig(BaseModel):
|
||||
"""JWT and auth-related configuration. Parsed once at startup.
|
||||
@@ -30,6 +32,32 @@ class AuthConfig(BaseModel):
|
||||
_auth_config: AuthConfig | None = None
|
||||
|
||||
|
||||
def _load_or_create_secret() -> str:
    """Load persisted JWT secret from ``{base_dir}/.jwt_secret``, or generate and persist a new one.

    Returns:
        The stored secret if the file exists and is non-empty, otherwise a
        newly generated secret that has been written to the file.

    Raises:
        RuntimeError: If the secret file cannot be read or written.

    The file is created with ``O_EXCL`` so that when several processes start
    at the same time, only one of them creates the file; the others read the
    winner's secret back instead of overwriting it with a different one.
    """
    from deerflow.config.paths import get_paths

    secret_file = get_paths().base_dir / _SECRET_FILE

    def _read_persisted() -> str:
        """Return the stripped file content, or "" if the file does not exist."""
        try:
            return secret_file.read_text(encoding="utf-8").strip()
        except FileNotFoundError:
            return ""
        except OSError as exc:
            raise RuntimeError(f"Failed to read JWT secret from {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can read its persisted auth secret.") from exc

    persisted = _read_persisted()
    if persisted:
        return persisted

    secret = secrets.token_urlsafe(32)
    try:
        secret_file.parent.mkdir(parents=True, exist_ok=True)
        try:
            fd = os.open(secret_file, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
        except FileExistsError:
            # Either another process created the file after our read above, or
            # an empty file was left behind. Prefer whatever is already stored.
            persisted = _read_persisted()
            if persisted:
                return persisted
            fd = os.open(secret_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            fh.write(secret)
    except OSError as exc:
        raise RuntimeError(f"Failed to persist JWT secret to {secret_file}. Set AUTH_JWT_SECRET explicitly or fix DEER_FLOW_HOME/base directory permissions so DeerFlow can store a stable auth secret.") from exc
    return secret
|
||||
|
||||
|
||||
def get_auth_config() -> AuthConfig:
|
||||
"""Get the global AuthConfig instance. Parses from env on first call."""
|
||||
global _auth_config
|
||||
@@ -39,11 +67,11 @@ def get_auth_config() -> AuthConfig:
|
||||
load_dotenv()
|
||||
jwt_secret = os.environ.get("AUTH_JWT_SECRET")
|
||||
if not jwt_secret:
|
||||
jwt_secret = secrets.token_urlsafe(32)
|
||||
jwt_secret = _load_or_create_secret()
|
||||
os.environ["AUTH_JWT_SECRET"] = jwt_secret
|
||||
logger.warning(
|
||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated ephemeral secret. "
|
||||
"Sessions will be invalidated on restart. "
|
||||
"⚠ AUTH_JWT_SECRET is not set — using an auto-generated secret "
|
||||
"persisted to .jwt_secret. Sessions will survive restarts. "
|
||||
"For production, add AUTH_JWT_SECRET to your .env file: "
|
||||
'python -c "import secrets; print(secrets.token_urlsafe(32))"'
|
||||
)
|
||||
|
||||
@@ -20,6 +20,9 @@ ACTIVE_CONTENT_MIME_TYPES = {
|
||||
"image/svg+xml",
|
||||
}
|
||||
|
||||
MAX_SKILL_ARCHIVE_MEMBER_BYTES = 16 * 1024 * 1024
|
||||
_SKILL_ARCHIVE_READ_CHUNK_SIZE = 64 * 1024
|
||||
|
||||
|
||||
def _build_content_disposition(disposition_type: str, filename: str) -> str:
|
||||
"""Build an RFC 5987 encoded Content-Disposition header value."""
|
||||
@@ -44,6 +47,22 @@ def is_text_file_by_content(path: Path, sample_size: int = 8192) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def _read_skill_archive_member(zip_ref: zipfile.ZipFile, info: zipfile.ZipInfo) -> bytes:
    """Return the bytes of one .skill archive member, bounded by a size cap.

    The declared uncompressed size is checked first; the member is then read
    in fixed-size chunks and the running total is checked again, so the cap
    is enforced on the bytes actually read as well as on the declared size.

    Raises:
        HTTPException: 413 when the member exceeds ``MAX_SKILL_ARCHIVE_MEMBER_BYTES``.
    """

    def _too_large() -> HTTPException:
        return HTTPException(status_code=413, detail="Skill archive member is too large to preview")

    if info.file_size > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
        raise _too_large()

    collected = bytearray()
    with zip_ref.open(info, "r") as member:
        for piece in iter(lambda: member.read(_SKILL_ARCHIVE_READ_CHUNK_SIZE), b""):
            if len(collected) + len(piece) > MAX_SKILL_ARCHIVE_MEMBER_BYTES:
                raise _too_large()
            collected += piece
    return bytes(collected)
|
||||
|
||||
|
||||
def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> bytes | None:
|
||||
"""Extract a file from a .skill ZIP archive.
|
||||
|
||||
@@ -60,16 +79,16 @@ def _extract_file_from_skill_archive(zip_path: Path, internal_path: str) -> byte
|
||||
try:
|
||||
with zipfile.ZipFile(zip_path, "r") as zip_ref:
|
||||
# List all files in the archive
|
||||
namelist = zip_ref.namelist()
|
||||
infos_by_name = {info.filename: info for info in zip_ref.infolist()}
|
||||
|
||||
# Try direct path first
|
||||
if internal_path in namelist:
|
||||
return zip_ref.read(internal_path)
|
||||
if internal_path in infos_by_name:
|
||||
return _read_skill_archive_member(zip_ref, infos_by_name[internal_path])
|
||||
|
||||
# Try with any top-level directory prefix (e.g., "skill-name/SKILL.md")
|
||||
for name in namelist:
|
||||
for name, info in infos_by_name.items():
|
||||
if name.endswith("/" + internal_path) or name == internal_path:
|
||||
return zip_ref.read(name)
|
||||
return _read_skill_archive_member(zip_ref, info)
|
||||
|
||||
# Not found
|
||||
return None
|
||||
|
||||
@@ -99,7 +99,7 @@ rm -f backend/.deer-flow/data/deerflow.db
|
||||
| `.deer-flow/users/{user_id}/memory.json` | 用户级 memory |
|
||||
| `.deer-flow/users/{user_id}/agents/{agent_name}/` | 用户自定义 agent 配置、SOUL 和 agent memory |
|
||||
| `.deer-flow/admin_initial_credentials.txt` | `reset_admin` 生成的新凭据文件(0600,读完应删除) |
|
||||
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成临时密钥,重启后 session 失效) |
|
||||
| `.env` 中的 `AUTH_JWT_SECRET` | JWT 签名密钥(未设置时自动生成并持久化到 `.deer-flow/.jwt_secret`,重启后 session 保持) |
|
||||
|
||||
### 生产环境建议
|
||||
|
||||
@@ -137,4 +137,4 @@ python -c "import secrets; print(secrets.token_urlsafe(32))"
|
||||
| 启动后没看到密码 | 当前实现不在启动日志输出密码 | 首次安装访问 `/setup`;忘记密码用 `reset_admin` |
|
||||
| `/login` 自动跳到 `/setup` | 系统还没有 admin | 在 `/setup` 创建第一个 admin |
|
||||
| 登录后 POST 返回 403 | CSRF token 缺失 | 确认前端已更新 |
|
||||
| 重启后需要重新登录 | `AUTH_JWT_SECRET` 未持久化 | 在 `.env` 中设置固定密钥 |
|
||||
| 重启后需要重新登录 | `.jwt_secret` 文件被删除且 `.env` 未设置 `AUTH_JWT_SECRET` | 在 `.env` 中设置固定密钥 |
|
||||
|
||||
+27
-22
@@ -104,45 +104,46 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
return "[Tool call was interrupted and did not return a result.]"
|
||||
|
||||
def _build_patched_messages(self, messages: list) -> list | None:
|
||||
"""Return a new message list with patches inserted at the correct positions.
|
||||
"""Return messages with tool results grouped after their tool-call AIMessage.
|
||||
|
||||
For each AIMessage with dangling tool_calls (no corresponding ToolMessage),
|
||||
a synthetic ToolMessage is inserted immediately after that AIMessage.
|
||||
Returns None if no patches are needed.
|
||||
This normalizes model-bound causal order before provider serialization while
|
||||
preserving already-valid transcripts unchanged.
|
||||
"""
|
||||
# Collect IDs of all existing ToolMessages
|
||||
existing_tool_msg_ids: set[str] = set()
|
||||
tool_messages_by_id: dict[str, ToolMessage] = {}
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage):
|
||||
existing_tool_msg_ids.add(msg.tool_call_id)
|
||||
tool_messages_by_id.setdefault(msg.tool_call_id, msg)
|
||||
|
||||
# Check if any patching is needed
|
||||
needs_patch = False
|
||||
tool_call_ids: set[str] = set()
|
||||
for msg in messages:
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids:
|
||||
needs_patch = True
|
||||
break
|
||||
if needs_patch:
|
||||
break
|
||||
if tc_id:
|
||||
tool_call_ids.add(tc_id)
|
||||
|
||||
if not needs_patch:
|
||||
return None
|
||||
|
||||
# Build new list with patches inserted right after each dangling AIMessage
|
||||
patched: list = []
|
||||
patched_ids: set[str] = set()
|
||||
consumed_tool_msg_ids: set[str] = set()
|
||||
patch_count = 0
|
||||
for msg in messages:
|
||||
if isinstance(msg, ToolMessage) and msg.tool_call_id in tool_call_ids:
|
||||
continue
|
||||
|
||||
patched.append(msg)
|
||||
if getattr(msg, "type", None) != "ai":
|
||||
continue
|
||||
|
||||
for tc in self._message_tool_calls(msg):
|
||||
tc_id = tc.get("id")
|
||||
if tc_id and tc_id not in existing_tool_msg_ids and tc_id not in patched_ids:
|
||||
if not tc_id or tc_id in consumed_tool_msg_ids:
|
||||
continue
|
||||
|
||||
existing_tool_msg = tool_messages_by_id.get(tc_id)
|
||||
if existing_tool_msg is not None:
|
||||
patched.append(existing_tool_msg)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
else:
|
||||
patched.append(
|
||||
ToolMessage(
|
||||
content=self._synthetic_tool_message_content(tc),
|
||||
@@ -151,10 +152,14 @@ class DanglingToolCallMiddleware(AgentMiddleware[AgentState]):
|
||||
status="error",
|
||||
)
|
||||
)
|
||||
patched_ids.add(tc_id)
|
||||
consumed_tool_msg_ids.add(tc_id)
|
||||
patch_count += 1
|
||||
|
||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||
if patched == messages:
|
||||
return None
|
||||
|
||||
if patch_count:
|
||||
logger.warning(f"Injecting {patch_count} placeholder ToolMessage(s) for dangling tool calls")
|
||||
return patched
|
||||
|
||||
@override
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Any, Protocol, override, runtime_checkable
|
||||
from langchain.agents import AgentState
|
||||
from langchain.agents.middleware import SummarizationMiddleware
|
||||
from langchain_core.messages import AIMessage, AnyMessage, HumanMessage, RemoveMessage, ToolMessage
|
||||
from langchain_core.messages.utils import get_buffer_string
|
||||
from langgraph.config import get_config
|
||||
from langgraph.graph.message import REMOVE_ALL_MESSAGES
|
||||
from langgraph.runtime import Runtime
|
||||
@@ -175,12 +176,84 @@ class DeerFlowSummarizationMiddleware(SummarizationMiddleware):
|
||||
]
|
||||
}
|
||||
|
||||
@override
|
||||
def _create_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||
"""Generate summary without emitting streaming events to the client.
|
||||
|
||||
Suppresses callbacks to prevent the internal summarization LLM call from
|
||||
producing visible AI message chunks in the frontend's ``messages-tuple``
|
||||
stream (issue #2804).
|
||||
"""
|
||||
if not messages_to_summarize:
|
||||
return "No previous conversation history."
|
||||
|
||||
trimmed = self._trim_messages_for_summary(messages_to_summarize)
|
||||
if not trimmed:
|
||||
return "Previous conversation was too long to summarize."
|
||||
|
||||
formatted = get_buffer_string(trimmed)
|
||||
|
||||
try:
|
||||
response = self.model.with_config(callbacks=[]).invoke(
|
||||
self.summary_prompt.format(messages=formatted).rstrip(),
|
||||
config={
|
||||
"metadata": {"lc_source": "summarization"},
|
||||
"callbacks": [],
|
||||
},
|
||||
)
|
||||
return self._extract_summary_text(response)
|
||||
except Exception as e:
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
@override
|
||||
async def _acreate_summary(self, messages_to_summarize: list[AnyMessage]) -> str:
|
||||
"""Generate summary without emitting streaming events to the client.
|
||||
|
||||
Suppresses callbacks to prevent the internal summarization LLM call from
|
||||
producing visible AI message chunks in the frontend's ``messages-tuple``
|
||||
stream (issue #2804).
|
||||
"""
|
||||
if not messages_to_summarize:
|
||||
return "No previous conversation history."
|
||||
|
||||
trimmed = self._trim_messages_for_summary(messages_to_summarize)
|
||||
if not trimmed:
|
||||
return "Previous conversation was too long to summarize."
|
||||
|
||||
formatted = get_buffer_string(trimmed)
|
||||
|
||||
try:
|
||||
response = await self.model.with_config(callbacks=[]).ainvoke(
|
||||
self.summary_prompt.format(messages=formatted).rstrip(),
|
||||
config={
|
||||
"metadata": {"lc_source": "summarization"},
|
||||
"callbacks": [],
|
||||
},
|
||||
)
|
||||
return self._extract_summary_text(response)
|
||||
except Exception as e:
|
||||
return f"Error generating summary: {e!s}"
|
||||
|
||||
def _extract_summary_text(self, response: Any) -> str:
    """Return the stripped summary text carried by a model response.

    Lookup order:
      1. ``response.text``. A plain string is used as-is. If it is a
         non-string callable (some message implementations expose ``text``
         as a method), it is called to obtain the value.
      2. ``response.content`` when ``text`` is missing or ``None``.

    A list value is treated as content blocks: string items and the
    ``"text"`` field of dict items are concatenated. Anything else is passed
    through ``str()``.
    """
    summary_text = getattr(response, "text", None)
    # Check str first: a str subclass may also be callable and must not be invoked.
    if not isinstance(summary_text, str) and callable(summary_text):
        summary_text = summary_text()
    if summary_text is None:
        summary_text = getattr(response, "content", "")

    if isinstance(summary_text, str):
        return summary_text.strip()

    if isinstance(summary_text, list):
        parts: list[str] = []
        for block in summary_text:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and isinstance(block.get("text"), str):
                parts.append(block["text"])
        return "".join(parts).strip()

    return str(summary_text).strip()
|
||||
|
||||
@override
def _build_new_messages(self, summary: str) -> list[HumanMessage]:
    """Wrap the summary in a single human message named ``summary``.

    The message is flagged with ``additional_kwargs={"hide_from_ui": True}`` so
    it can be filtered out of the frontend display while still serving as
    context for the model.

    Args:
        summary: The summary text to embed in the message content.

    Returns:
        A one-element list holding the hidden summary message.
    """
    # A leftover earlier ``return`` without ``hide_from_ui`` used to shadow this
    # one and made the flag unreachable; only the flagged message is built now.
    return [
        HumanMessage(
            content=f"Here is a summary of the conversation to date:\n\n{summary}",
            name="summary",
            additional_kwargs={"hide_from_ui": True},
        )
    ]
|
||||
|
||||
def _preserve_dynamic_context_reminders(
|
||||
self,
|
||||
|
||||
@@ -7,17 +7,21 @@ reminder message so the model still knows about the outstanding todo list.
|
||||
|
||||
Additionally, this middleware prevents the agent from exiting the loop while
|
||||
there are still incomplete todo items. When the model produces a final response
|
||||
(no tool calls) but todos are not yet complete, the middleware injects a reminder
|
||||
and jumps back to the model node to force continued engagement.
|
||||
(no tool calls) but todos are not yet complete, the middleware queues a reminder
|
||||
for the next model request and jumps back to the model node to force continued
|
||||
engagement. The completion reminder is injected via ``wrap_model_call`` instead
|
||||
of being persisted into graph state as a normal user-visible message.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import threading
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, override
|
||||
|
||||
from langchain.agents.middleware import TodoListMiddleware
|
||||
from langchain.agents.middleware.todo import PlanningState, Todo
|
||||
from langchain.agents.middleware.types import hook_config
|
||||
from langchain.agents.middleware.types import ModelCallResult, ModelRequest, ModelResponse, hook_config
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from langgraph.runtime import Runtime
|
||||
|
||||
@@ -55,6 +59,51 @@ def _format_todos(todos: list[Todo]) -> str:
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _format_completion_reminder(todos: list[Todo]) -> str:
|
||||
"""Format a completion reminder for incomplete todo items."""
|
||||
incomplete = [t for t in todos if t.get("status") != "completed"]
|
||||
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
||||
return (
|
||||
"<system_reminder>\n"
|
||||
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
||||
f"{incomplete_text}\n\n"
|
||||
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
||||
"as you finish them, and only respond when all items are done.\n"
|
||||
"</system_reminder>"
|
||||
)
|
||||
|
||||
|
||||
_TOOL_CALL_FINISH_REASONS = {"tool_calls", "function_call"}
|
||||
|
||||
|
||||
def _has_tool_call_intent_or_error(message: AIMessage) -> bool:
|
||||
"""Return True when an AIMessage is not a clean final answer.
|
||||
|
||||
Todo completion reminders should only fire when the model has produced a
|
||||
plain final response. Provider/tool parsing details have moved across
|
||||
LangChain versions and integrations, so keep all tool-intent/error signals
|
||||
behind this helper instead of checking one concrete field at the call site.
|
||||
"""
|
||||
if message.tool_calls:
|
||||
return True
|
||||
|
||||
if getattr(message, "invalid_tool_calls", None):
|
||||
return True
|
||||
|
||||
# Backward/provider compatibility: some integrations preserve raw or legacy
|
||||
# tool-call intent in additional_kwargs even when structured tool_calls is
|
||||
# empty. If this helper changes, update the matching sentinel test
|
||||
# `TestToolCallIntentOrError.test_langchain_ai_message_tool_fields_are_explicitly_handled`;
|
||||
# if that test fails after a LangChain upgrade, review this helper so new
|
||||
# tool-call/error fields are not silently treated as clean final answers.
|
||||
additional_kwargs = getattr(message, "additional_kwargs", {}) or {}
|
||||
if additional_kwargs.get("tool_calls") or additional_kwargs.get("function_call"):
|
||||
return True
|
||||
|
||||
response_metadata = getattr(message, "response_metadata", {}) or {}
|
||||
return response_metadata.get("finish_reason") in _TOOL_CALL_FINISH_REASONS
|
||||
|
||||
|
||||
class TodoMiddleware(TodoListMiddleware):
|
||||
"""Extends TodoListMiddleware with `write_todos` context-loss detection.
|
||||
|
||||
@@ -89,6 +138,7 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
formatted = _format_todos(todos)
|
||||
reminder = HumanMessage(
|
||||
name="todo_reminder",
|
||||
additional_kwargs={"hide_from_ui": True},
|
||||
content=(
|
||||
"<system_reminder>\n"
|
||||
"Your todo list from earlier is no longer visible in the current context window, "
|
||||
@@ -113,6 +163,100 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
# Maximum number of completion reminders before allowing the agent to exit.
|
||||
# This prevents infinite loops when the agent cannot make further progress.
|
||||
_MAX_COMPLETION_REMINDERS = 2
|
||||
# Hard cap for per-run reminder bookkeeping in long-lived middleware instances.
|
||||
_MAX_COMPLETION_REMINDER_KEYS = 4096
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
    """Initialise the parent middleware, then the completion-reminder bookkeeping.

    All positional and keyword arguments are forwarded unchanged to the parent
    initialiser.
    """
    super().__init__(*args, **kwargs)

    # Counter that hands out increasing "last touched" stamps.
    self._completion_reminder_next_order = 0
    # Per-key state, each map keyed by the tuple built in ``_pending_key``.
    self._completion_reminder_touch_order: dict[tuple[str, str], int] = {}
    self._completion_reminder_counts: dict[tuple[str, str], int] = {}
    self._pending_completion_reminders: dict[tuple[str, str], list[str]] = {}
    # Held by the methods that read or mutate the maps above.
    self._lock = threading.Lock()
|
||||
|
||||
@staticmethod
|
||||
def _get_thread_id(runtime: Runtime) -> str:
|
||||
context = getattr(runtime, "context", None)
|
||||
thread_id = context.get("thread_id") if context else None
|
||||
return str(thread_id) if thread_id else "default"
|
||||
|
||||
@staticmethod
|
||||
def _get_run_id(runtime: Runtime) -> str:
|
||||
context = getattr(runtime, "context", None)
|
||||
run_id = context.get("run_id") if context else None
|
||||
return str(run_id) if run_id else "default"
|
||||
|
||||
def _pending_key(self, runtime: Runtime) -> tuple[str, str]:
|
||||
return self._get_thread_id(runtime), self._get_run_id(runtime)
|
||||
|
||||
def _touch_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
||||
self._completion_reminder_next_order += 1
|
||||
self._completion_reminder_touch_order[key] = self._completion_reminder_next_order
|
||||
|
||||
def _completion_reminder_keys_locked(self) -> set[tuple[str, str]]:
|
||||
keys = set(self._pending_completion_reminders)
|
||||
keys.update(self._completion_reminder_counts)
|
||||
keys.update(self._completion_reminder_touch_order)
|
||||
return keys
|
||||
|
||||
def _drop_completion_reminder_key_locked(self, key: tuple[str, str]) -> None:
|
||||
self._pending_completion_reminders.pop(key, None)
|
||||
self._completion_reminder_counts.pop(key, None)
|
||||
self._completion_reminder_touch_order.pop(key, None)
|
||||
|
||||
def _prune_completion_reminder_state_locked(self, protected_key: tuple[str, str]) -> None:
|
||||
keys = self._completion_reminder_keys_locked()
|
||||
overflow = len(keys) - self._MAX_COMPLETION_REMINDER_KEYS
|
||||
if overflow <= 0:
|
||||
return
|
||||
|
||||
candidates = [key for key in keys if key != protected_key]
|
||||
candidates.sort(key=lambda key: self._completion_reminder_touch_order.get(key, 0))
|
||||
for key in candidates[:overflow]:
|
||||
self._drop_completion_reminder_key_locked(key)
|
||||
|
||||
def _queue_completion_reminder(self, runtime: Runtime, reminder: str) -> None:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
self._pending_completion_reminders.setdefault(key, []).append(reminder)
|
||||
self._completion_reminder_counts[key] = self._completion_reminder_counts.get(key, 0) + 1
|
||||
self._touch_completion_reminder_key_locked(key)
|
||||
self._prune_completion_reminder_state_locked(protected_key=key)
|
||||
|
||||
def _completion_reminder_count_for_runtime(self, runtime: Runtime) -> int:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
return self._completion_reminder_counts.get(key, 0)
|
||||
|
||||
def _drain_completion_reminders(self, runtime: Runtime) -> list[str]:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
reminders = self._pending_completion_reminders.pop(key, [])
|
||||
if reminders or key in self._completion_reminder_counts:
|
||||
self._touch_completion_reminder_key_locked(key)
|
||||
return reminders
|
||||
|
||||
def _clear_other_run_completion_reminders(self, runtime: Runtime) -> None:
|
||||
thread_id, current_run_id = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
for key in self._completion_reminder_keys_locked():
|
||||
if key[0] == thread_id and key[1] != current_run_id:
|
||||
self._drop_completion_reminder_key_locked(key)
|
||||
|
||||
def _clear_current_run_completion_reminders(self, runtime: Runtime) -> None:
|
||||
key = self._pending_key(runtime)
|
||||
with self._lock:
|
||||
self._drop_completion_reminder_key_locked(key)
|
||||
|
||||
@override
|
||||
def before_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_other_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
async def abefore_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_other_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@hook_config(can_jump_to=["model"])
|
||||
@override
|
||||
@@ -137,10 +281,12 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
if base_result is not None:
|
||||
return base_result
|
||||
|
||||
# 2. Only intervene when the agent wants to exit (no tool calls).
|
||||
# 2. Only intervene when the agent wants to exit cleanly. Tool-call
|
||||
# intent or tool-call parse errors should be handled by the tool path
|
||||
# instead of being masked by todo reminders.
|
||||
messages = state.get("messages") or []
|
||||
last_ai = next((m for m in reversed(messages) if isinstance(m, AIMessage)), None)
|
||||
if not last_ai or last_ai.tool_calls:
|
||||
if not last_ai or _has_tool_call_intent_or_error(last_ai):
|
||||
return None
|
||||
|
||||
# 3. Allow exit when all todos are completed or there are no todos.
|
||||
@@ -149,24 +295,14 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
return None
|
||||
|
||||
# 4. Enforce a reminder cap to prevent infinite re-engagement loops.
|
||||
if _completion_reminder_count(messages) >= self._MAX_COMPLETION_REMINDERS:
|
||||
if self._completion_reminder_count_for_runtime(runtime) >= self._MAX_COMPLETION_REMINDERS:
|
||||
return None
|
||||
|
||||
# 5. Inject a reminder and force the agent back to the model.
|
||||
incomplete = [t for t in todos if t.get("status") != "completed"]
|
||||
incomplete_text = "\n".join(f"- [{t.get('status', 'pending')}] {t.get('content', '')}" for t in incomplete)
|
||||
reminder = HumanMessage(
|
||||
name="todo_completion_reminder",
|
||||
content=(
|
||||
"<system_reminder>\n"
|
||||
"You have incomplete todo items that must be finished before giving your final response:\n\n"
|
||||
f"{incomplete_text}\n\n"
|
||||
"Please continue working on these tasks. Call `write_todos` to mark items as completed "
|
||||
"as you finish them, and only respond when all items are done.\n"
|
||||
"</system_reminder>"
|
||||
),
|
||||
)
|
||||
return {"jump_to": "model", "messages": [reminder]}
|
||||
# 5. Queue a reminder for the next model request and jump back. We must
|
||||
# not persist this control prompt as a normal HumanMessage, otherwise it
|
||||
# can leak into user-visible message streams and saved transcripts.
|
||||
self._queue_completion_reminder(runtime, _format_completion_reminder(todos))
|
||||
return {"jump_to": "model"}
|
||||
|
||||
@override
|
||||
@hook_config(can_jump_to=["model"])
|
||||
@@ -177,3 +313,47 @@ class TodoMiddleware(TodoListMiddleware):
|
||||
) -> dict[str, Any] | None:
|
||||
"""Async version of after_model."""
|
||||
return self.after_model(state, runtime)
|
||||
|
||||
@staticmethod
|
||||
def _format_pending_completion_reminders(reminders: list[str]) -> str:
|
||||
return "\n\n".join(dict.fromkeys(reminders))
|
||||
|
||||
def _augment_request(self, request: ModelRequest) -> ModelRequest:
|
||||
reminders = self._drain_completion_reminders(request.runtime)
|
||||
if not reminders:
|
||||
return request
|
||||
new_messages = [
|
||||
*request.messages,
|
||||
HumanMessage(
|
||||
content=self._format_pending_completion_reminders(reminders),
|
||||
name="todo_completion_reminder",
|
||||
additional_kwargs={"hide_from_ui": True},
|
||||
),
|
||||
]
|
||||
return request.override(messages=new_messages)
|
||||
|
||||
@override
|
||||
def wrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], ModelResponse],
|
||||
) -> ModelCallResult:
|
||||
return handler(self._augment_request(request))
|
||||
|
||||
@override
|
||||
async def awrap_model_call(
|
||||
self,
|
||||
request: ModelRequest,
|
||||
handler: Callable[[ModelRequest], Awaitable[ModelResponse]],
|
||||
) -> ModelCallResult:
|
||||
return await handler(self._augment_request(request))
|
||||
|
||||
@override
|
||||
def after_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_current_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@override
|
||||
async def aafter_agent(self, state: PlanningState, runtime: Runtime) -> dict[str, Any] | None:
|
||||
self._clear_current_run_completion_reminders(runtime)
|
||||
return None
|
||||
|
||||
@@ -25,6 +25,7 @@ dependencies = [
|
||||
|
||||
[project.optional-dependencies]
|
||||
postgres = ["deerflow-harness[postgres]"]
|
||||
discord = ["discord.py>=2.7.0"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
|
||||
@@ -4,6 +4,7 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import call_unwrapped, make_authed_test_app
|
||||
from fastapi import HTTPException
|
||||
from fastapi.testclient import TestClient
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import FileResponse
|
||||
@@ -102,3 +103,17 @@ def test_get_artifact_download_true_forces_attachment_for_skill_archive(tmp_path
|
||||
assert response.status_code == 200
|
||||
assert response.text == "hello"
|
||||
assert response.headers.get("content-disposition", "").startswith("attachment;")
|
||||
|
||||
|
||||
def test_skill_archive_preview_rejects_oversized_member_before_decompression(tmp_path) -> None:
    """A small archive whose member decompresses past the limit is rejected with HTTP 413."""
    limit = artifacts_router.MAX_SKILL_ARCHIVE_MEMBER_BYTES
    archive = tmp_path / "sample.skill"

    # One byte over the limit of highly compressible data.
    with zipfile.ZipFile(archive, "w", compression=zipfile.ZIP_DEFLATED, compresslevel=9) as zf:
        zf.writestr("SKILL.md", b"A" * (limit + 1))

    # The archive on disk stays below the limit; only the member is oversized.
    assert archive.stat().st_size < limit

    with pytest.raises(HTTPException) as exc_info:
        artifacts_router._extract_file_from_skill_archive(archive, "SKILL.md")

    assert exc_info.value.status_code == 413
|
||||
|
||||
@@ -5,28 +5,26 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.gateway.auth.config import AuthConfig
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
|
||||
def test_auth_config_defaults():
    """A config built with only a JWT secret has a 7-day token expiry."""
    # Build the config once, through the module alias used by the sibling tests.
    config = cfg.AuthConfig(jwt_secret="test-secret-key-123")
    assert config.token_expiry_days == 7
|
||||
|
||||
|
||||
def test_auth_config_token_expiry_range():
    """``token_expiry_days`` accepts 1 and 30 and rejects 0 and 31."""
    # Boundary values inside the range must construct without error.
    cfg.AuthConfig(jwt_secret="s", token_expiry_days=1)
    cfg.AuthConfig(jwt_secret="s", token_expiry_days=30)

    # One statement per ``raises`` block: a second statement after the raising
    # one would never execute.
    with pytest.raises(Exception):
        cfg.AuthConfig(jwt_secret="s", token_expiry_days=0)
    with pytest.raises(Exception):
        cfg.AuthConfig(jwt_secret="s", token_expiry_days=31)
|
||||
|
||||
|
||||
def test_auth_config_from_env():
|
||||
env = {"AUTH_JWT_SECRET": "test-jwt-secret-from-env"}
|
||||
with patch.dict(os.environ, env, clear=False):
|
||||
import app.gateway.auth.config as cfg
|
||||
|
||||
old = cfg._auth_config
|
||||
cfg._auth_config = None
|
||||
try:
|
||||
@@ -36,19 +34,57 @@ def test_auth_config_from_env():
|
||||
cfg._auth_config = old
|
||||
|
||||
|
||||
def test_auth_config_missing_secret_generates_and_persists(tmp_path, caplog):
    """With no ``AUTH_JWT_SECRET`` set, a secret is generated, warned about, and persisted.

    The paths lookup is patched to use ``tmp_path`` as base directory, and the
    test expects the generated secret to be written to ``.jwt_secret`` there.
    """
    import logging

    from deerflow.config.paths import Paths

    old = cfg._auth_config
    cfg._auth_config = None  # force get_auth_config() to rebuild the config
    secret_file = tmp_path / ".jwt_secret"
    try:
        with patch.dict(os.environ, {}, clear=True):
            os.environ.pop("AUTH_JWT_SECRET", None)
            with patch("deerflow.config.paths.get_paths", return_value=Paths(base_dir=tmp_path)), caplog.at_level(logging.WARNING):
                config = cfg.get_auth_config()
            assert config.jwt_secret
            assert any("AUTH_JWT_SECRET" in msg for msg in caplog.messages)
            assert secret_file.exists()
            assert secret_file.read_text().strip() == config.jwt_secret
    finally:
        # Restore the module-level cached config for other tests.
        cfg._auth_config = old
|
||||
|
||||
|
||||
def test_auth_config_reuses_persisted_secret(tmp_path):
    """A secret already stored in ``.jwt_secret`` is reused when the env var is unset."""
    from deerflow.config.paths import Paths

    saved_config = cfg._auth_config
    cfg._auth_config = None  # force get_auth_config() to rebuild the config
    persisted = "persisted-secret-from-file-min-32-chars!!"
    secret_path = tmp_path / ".jwt_secret"
    secret_path.write_text(persisted, encoding="utf-8")
    try:
        with patch.dict(os.environ, {}, clear=True):
            os.environ.pop("AUTH_JWT_SECRET", None)
            fake_paths = Paths(base_dir=tmp_path)
            with patch("deerflow.config.paths.get_paths", return_value=fake_paths):
                config = cfg.get_auth_config()
            assert config.jwt_secret == persisted
    finally:
        # Restore the module-level cached config for other tests.
        cfg._auth_config = saved_config
|
||||
|
||||
|
||||
def test_auth_config_empty_secret_file_generates_new(tmp_path):
    """An empty ``.jwt_secret`` file is replaced by a newly generated secret."""
    from deerflow.config.paths import Paths

    saved_config = cfg._auth_config
    cfg._auth_config = None  # force get_auth_config() to rebuild the config
    secret_path = tmp_path / ".jwt_secret"
    secret_path.write_text("", encoding="utf-8")
    try:
        with patch.dict(os.environ, {}, clear=True):
            os.environ.pop("AUTH_JWT_SECRET", None)
            fake_paths = Paths(base_dir=tmp_path)
            with patch("deerflow.config.paths.get_paths", return_value=fake_paths):
                config = cfg.get_auth_config()
            assert config.jwt_secret
            assert len(config.jwt_secret) > 20
            assert (tmp_path / ".jwt_secret").read_text().strip() == config.jwt_secret
    finally:
        # Restore the module-level cached config for other tests.
        cfg._auth_config = saved_config
|
||||
|
||||
@@ -761,7 +761,7 @@ class TestChannelManager:
|
||||
|
||||
history_by_checkpoint: dict[tuple[str, str], list[str]] = {}
|
||||
|
||||
async def _runs_wait(thread_id, assistant_id, *, input, config, context):
|
||||
async def _runs_wait(thread_id, assistant_id, *, input, config, context, multitask_strategy=None):
|
||||
del assistant_id, context # unused in this test, kept for signature parity
|
||||
|
||||
checkpoint_ns = config.get("configurable", {}).get("checkpoint_ns")
|
||||
|
||||
@@ -158,6 +158,88 @@ class TestBuildPatchedMessagesPatching:
|
||||
assert patched[1].name == "bash"
|
||||
assert patched[1].status == "error"
|
||||
|
||||
def test_non_adjacent_tool_result_is_moved_next_to_tool_call(self):
    """A tool result separated from its AI tool call ends up right after that call."""
    history = [
        _ai_with_tool_calls([_tc("bash", "call_1")]),
        HumanMessage(content="interruption"),
        _tool_msg("call_1", "bash"),
    ]

    patched = DanglingToolCallMiddleware()._build_patched_messages(history)

    assert patched is not None
    # Expected order: AI call, its tool result, then the interrupting human message.
    assert isinstance(patched[0], AIMessage)
    assert isinstance(patched[1], ToolMessage)
    assert patched[1].tool_call_id == "call_1"
    assert isinstance(patched[2], HumanMessage)
|
||||
|
||||
def test_multiple_tool_results_stay_grouped_after_ai_tool_call(self):
    """Out-of-order results for one AI turn are grouped after it in tool-call order."""
    history = [
        _ai_with_tool_calls([_tc("bash", "call_1"), _tc("read", "call_2")]),
        HumanMessage(content="interruption"),
        _tool_msg("call_2", "read"),
        _tool_msg("call_1", "bash"),
    ]

    patched = DanglingToolCallMiddleware()._build_patched_messages(history)

    assert patched is not None
    assert isinstance(patched[0], AIMessage)
    assert isinstance(patched[1], ToolMessage)
    assert isinstance(patched[2], ToolMessage)
    # Results follow the order of the tool calls, not their original positions.
    assert [result.tool_call_id for result in patched[1:3]] == ["call_1", "call_2"]
    assert isinstance(patched[3], HumanMessage)
|
||||
|
||||
def test_valid_adjacent_tool_results_are_unchanged(self):
    """A history whose tool result already follows its call needs no patching."""
    well_formed = [
        _ai_with_tool_calls([_tc("bash", "call_1")]),
        _tool_msg("call_1", "bash"),
        HumanMessage(content="next"),
    ]

    patched = DanglingToolCallMiddleware()._build_patched_messages(well_formed)

    # ``None`` is the "nothing to patch" result.
    assert patched is None
|
||||
|
||||
def test_tool_results_are_grouped_with_their_own_ai_turn_across_multiple_ai_messages(self):
    """Each tool result is placed after the AI message that issued its call."""
    history = [
        _ai_with_tool_calls([_tc("bash", "call_1")]),
        HumanMessage(content="interruption"),
        _ai_with_tool_calls([_tc("read", "call_2")]),
        _tool_msg("call_1", "bash"),
        _tool_msg("call_2", "read"),
    ]

    patched = DanglingToolCallMiddleware()._build_patched_messages(history)

    assert patched is not None
    # First AI turn followed by its own result.
    assert isinstance(patched[0], AIMessage)
    assert isinstance(patched[1], ToolMessage)
    assert patched[1].tool_call_id == "call_1"
    assert isinstance(patched[2], HumanMessage)
    # Second AI turn followed by its own result.
    assert isinstance(patched[3], AIMessage)
    assert isinstance(patched[4], ToolMessage)
    assert patched[4].tool_call_id == "call_2"
|
||||
|
||||
def test_orphan_tool_message_is_preserved_during_grouping(self):
    """A tool message with no matching call survives regrouping exactly once."""
    orphan = _tool_msg("orphan_call", "orphan")
    history = [
        _ai_with_tool_calls([_tc("bash", "call_1")]),
        orphan,
        HumanMessage(content="interruption"),
        _tool_msg("call_1", "bash"),
    ]

    patched = DanglingToolCallMiddleware()._build_patched_messages(history)

    assert patched is not None
    assert isinstance(patched[0], AIMessage)
    assert isinstance(patched[1], ToolMessage)
    assert patched[1].tool_call_id == "call_1"
    # The orphan is neither dropped nor duplicated.
    assert orphan in patched
    assert patched.count(orphan) == 1
|
||||
|
||||
def test_invalid_tool_call_is_patched(self):
|
||||
mw = DanglingToolCallMiddleware()
|
||||
msgs = [_ai_with_invalid_tool_calls([_invalid_tc()])]
|
||||
|
||||
@@ -454,7 +454,6 @@ class TestAStream:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_with_tools_emits_tool_call_chunk(self):
|
||||
|
||||
tool_calls = [{"name": "fn", "args": {}, "id": "c1"}]
|
||||
with patch.object(MindIEChatModel, "_agenerate", new_callable=AsyncMock) as mock_ag, patch.object(MindIEChatModel, "__init__", return_value=None):
|
||||
mock_ag.return_value = _make_chat_result("ok", tool_calls=tool_calls)
|
||||
|
||||
@@ -56,7 +56,8 @@ def _middleware(
|
||||
preserve_recent_skill_tokens_per_skill: int = 0,
|
||||
) -> DeerFlowSummarizationMiddleware:
|
||||
model = MagicMock()
|
||||
model.invoke.return_value = SimpleNamespace(text="compressed summary")
|
||||
model.invoke.return_value = AIMessage(content="compressed summary")
|
||||
model.with_config.return_value.invoke.return_value = AIMessage(content="compressed summary")
|
||||
return DeerFlowSummarizationMiddleware(
|
||||
model=model,
|
||||
trigger=trigger,
|
||||
@@ -642,6 +643,69 @@ def test_memory_flush_hook_preserves_agent_scoped_memory(monkeypatch: pytest.Mon
|
||||
assert queue.add_nowait.call_args.kwargs["agent_name"] == "research-agent"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Issue #2804: summary text must not leak to the frontend via streaming
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_build_new_messages_sets_hide_from_ui() -> None:
    """The summary HumanMessage must carry hide_from_ui so the frontend filters it."""
    built = _middleware()._build_new_messages("test summary")

    assert len(built) == 1
    summary_message = built[0]
    assert summary_message.name == "summary"
    assert summary_message.additional_kwargs.get("hide_from_ui") is True
    assert "test summary" in summary_message.content
|
||||
|
||||
|
||||
def test_create_summary_suppresses_callbacks() -> None:
    """_create_summary must bind callbacks=[] on the model AND pass callbacks=[]
    in the invoke config to suppress inherited LangGraph stream callbacks."""
    middleware = _middleware()

    middleware._create_summary(_messages())

    # The model is re-bound with empty callbacks exactly once.
    middleware.model.with_config.assert_called_once_with(callbacks=[])
    silenced_model = middleware.model.with_config.return_value
    silenced_model.invoke.assert_called_once()

    # The per-call config repeats the suppression and tags the call's source.
    invoke_config = silenced_model.invoke.call_args.kwargs.get("config")
    assert invoke_config is not None
    assert invoke_config.get("callbacks") == []
    assert invoke_config.get("metadata", {}).get("lc_source") == "summarization"
|
||||
|
||||
|
||||
@pytest.mark.anyio
async def test_acreate_summary_suppresses_callbacks() -> None:
    """_acreate_summary must bind callbacks=[] on the model AND pass callbacks=[]
    in the ainvoke config to suppress inherited LangGraph stream callbacks."""
    middleware = _middleware()
    fake_reply = AIMessage(content="async summary")
    middleware.model.with_config.return_value.ainvoke = mock.AsyncMock(return_value=fake_reply)

    await middleware._acreate_summary(_messages())

    # The model is re-bound with empty callbacks exactly once.
    middleware.model.with_config.assert_called_once_with(callbacks=[])
    silenced_model = middleware.model.with_config.return_value
    silenced_model.ainvoke.assert_called_once()

    # The per-call config repeats the suppression and tags the call's source.
    ainvoke_config = silenced_model.ainvoke.call_args.kwargs.get("config")
    assert ainvoke_config is not None
    assert ainvoke_config.get("callbacks") == []
    assert ainvoke_config.get("metadata", {}).get("lc_source") == "summarization"
|
||||
|
||||
|
||||
def test_before_model_summary_message_has_hide_from_ui() -> None:
    """End-to-end: the emitted state update contains a summary message with hide_from_ui."""
    state_update = _middleware().before_model({"messages": _messages()}, _runtime())

    # The summary message is expected at index 1 of the emitted messages.
    summary_message = state_update["messages"][1]
    assert summary_message.name == "summary"
    assert summary_message.additional_kwargs.get("hide_from_ui") is True
|
||||
|
||||
|
||||
def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
queue = MagicMock()
|
||||
monkeypatch.setattr("deerflow.agents.memory.summarization_hook.get_memory_config", lambda: MemoryConfig(enabled=True))
|
||||
@@ -659,3 +723,17 @@ def test_memory_flush_hook_passes_runtime_user_id(monkeypatch: pytest.MonkeyPatc
|
||||
|
||||
queue.add_nowait.assert_called_once()
|
||||
assert queue.add_nowait.call_args.kwargs["user_id"] == "alice"
|
||||
|
||||
|
||||
def test_extract_summary_text_normalizes_list_content_blocks() -> None:
    """AIMessage.content can be a list of content blocks; _extract_summary_text
    must normalize to plain text via the .text property instead of producing
    a Python repr like [{'type': 'text', 'text': 'summary'}]."""
    middleware = _middleware()

    # List-of-blocks content is flattened to its text.
    block_reply = AIMessage(content=[{"type": "text", "text": "A summary of the chat."}])
    assert middleware._extract_summary_text(block_reply) == "A summary of the chat."

    # Plain string content still works.
    plain_reply = AIMessage(content="Plain summary")
    assert middleware._extract_summary_text(plain_reply) == "Plain summary"
|
||||
|
||||
@@ -93,7 +93,7 @@ class TestTitleMiddlewareCoreLogic:
|
||||
assert middleware._should_generate_title(state) is False
|
||||
|
||||
def test_generate_title_uses_async_model_and_respects_max_chars(self, monkeypatch):
|
||||
_set_test_title_config(max_chars=12)
|
||||
_set_test_title_config(max_chars=12, model_name=None)
|
||||
middleware = TitleMiddleware()
|
||||
model = MagicMock()
|
||||
model.ainvoke = AsyncMock(return_value=AIMessage(content="短标题"))
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
"""Tests for TodoMiddleware context-loss detection."""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.language_models.fake_chat_models import FakeMessagesListChatModel
|
||||
from langchain_core.messages import AIMessage, HumanMessage
|
||||
from pydantic import PrivateAttr
|
||||
|
||||
from deerflow.agents.middlewares.todo_middleware import (
|
||||
TodoMiddleware,
|
||||
_completion_reminder_count,
|
||||
_format_todos,
|
||||
_has_tool_call_intent_or_error,
|
||||
_reminder_in_messages,
|
||||
_todos_in_messages,
|
||||
)
|
||||
@@ -22,9 +27,35 @@ def _reminder_msg():
|
||||
return HumanMessage(name="todo_reminder", content="reminder")
|
||||
|
||||
|
||||
class _CapturingFakeMessagesListChatModel(FakeMessagesListChatModel):
    """Fake chat model that records the message list of every ``_generate`` call."""

    # Private attribute holding one snapshot per ``_generate`` call.
    _seen_messages: list[list[Any]] = PrivateAttr(default_factory=list)

    @property
    def seen_messages(self) -> list[list[Any]]:
        """Message lists received so far, in call order."""
        return self._seen_messages

    def bind_tools(self, tools, *, tool_choice=None, **kwargs):
        """Ignore the tool binding and return this same model instance."""
        return self

    def _generate(self, messages, stop=None, run_manager=None, **kwargs):
        """Snapshot ``messages``, then delegate to the parent ``_generate``."""
        snapshot = [*messages]
        self._seen_messages.append(snapshot)
        return super()._generate(messages, stop=stop, run_manager=run_manager, **kwargs)
|
||||
|
||||
|
||||
def _make_runtime():
|
||||
runtime = MagicMock()
|
||||
runtime.context = {"thread_id": "test-thread"}
|
||||
runtime.context = {"thread_id": "test-thread", "run_id": "test-run"}
|
||||
return runtime
|
||||
|
||||
|
||||
def _make_runtime_for(thread_id: str, run_id: str):
    """Return a mock runtime whose context holds the given thread and run ids."""
    runtime = _make_runtime()
    # Replace the default context with the caller's ids.
    runtime.context = dict(thread_id=thread_id, run_id=run_id)
    return runtime
|
||||
|
||||
|
||||
@@ -161,10 +192,62 @@ def _completion_reminder_msg():
|
||||
return HumanMessage(name="todo_completion_reminder", content="finish your todos")
|
||||
|
||||
|
||||
def _todo_completion_reminders(messages):
|
||||
reminders = []
|
||||
for message in messages:
|
||||
if isinstance(message, HumanMessage) and message.name == "todo_completion_reminder":
|
||||
reminders.append(message)
|
||||
return reminders
|
||||
|
||||
|
||||
def _ai_no_tool_calls():
    """Return an AI message with plain text content and no tool-call arguments."""
    final_answer = AIMessage(content="I'm done!")
    return final_answer
|
||||
|
||||
|
||||
def _ai_with_invalid_tool_calls():
    """Return an AI message with no valid tool calls and one invalid tool call."""
    unparsable_call = {
        "type": "invalid_tool_call",
        "id": "write_file:36",
        "name": "write_file",
        "args": "{invalid",
        "error": "Failed to parse tool arguments",
    }
    return AIMessage(content="", tool_calls=[], invalid_tool_calls=[unparsable_call])
|
||||
|
||||
|
||||
def _ai_with_raw_provider_tool_calls():
    """Return an AI message whose only tool call sits in ``additional_kwargs["tool_calls"]``."""
    raw_call = {
        "id": "raw-tool-call",
        "type": "function",
        "function": {"name": "write_file", "arguments": '{"path":"report.md"}'},
    }
    return AIMessage(
        content="",
        tool_calls=[],
        invalid_tool_calls=[],
        additional_kwargs={"tool_calls": [raw_call]},
    )
|
||||
|
||||
|
||||
def _ai_with_legacy_function_call():
|
||||
return AIMessage(
|
||||
content="",
|
||||
additional_kwargs={"function_call": {"name": "write_file", "arguments": '{"path":"report.md"}'}},
|
||||
)
|
||||
|
||||
|
||||
def _ai_with_tool_finish_reason():
|
||||
return AIMessage(content="", response_metadata={"finish_reason": "tool_calls"})
|
||||
|
||||
|
||||
def _incomplete_todos():
|
||||
return [
|
||||
{"status": "completed", "content": "Step 1"},
|
||||
@@ -194,6 +277,36 @@ class TestCompletionReminderCount:
|
||||
assert _completion_reminder_count(msgs) == 1
|
||||
|
||||
|
||||
class TestToolCallIntentOrError:
|
||||
def test_false_for_plain_final_answer(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_no_tool_calls()) is False
|
||||
|
||||
def test_true_for_structured_tool_calls(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_with_write_todos()) is True
|
||||
|
||||
def test_true_for_invalid_tool_calls(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_with_invalid_tool_calls()) is True
|
||||
|
||||
def test_true_for_raw_provider_tool_calls(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_with_raw_provider_tool_calls()) is True
|
||||
|
||||
def test_true_for_legacy_function_call(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_with_legacy_function_call()) is True
|
||||
|
||||
def test_true_for_tool_finish_reason(self):
|
||||
assert _has_tool_call_intent_or_error(_ai_with_tool_finish_reason()) is True
|
||||
|
||||
def test_langchain_ai_message_tool_fields_are_explicitly_handled(self):
|
||||
# Sentinel for LangChain compatibility: if future AIMessage versions add
|
||||
# new top-level tool/function-call fields, this test should fail. When
|
||||
# it does, update `_has_tool_call_intent_or_error()` so the completion
|
||||
# reminder guard explicitly decides whether each new field means "not a
|
||||
# clean final answer"; the helper has a matching comment pointing back
|
||||
# to this sentinel.
|
||||
tool_related_fields = {name for name in AIMessage.model_fields if "tool" in name.lower() or ("function" in name.lower() and "call" in name.lower())}
|
||||
assert tool_related_fields <= {"tool_calls", "invalid_tool_calls"}
|
||||
|
||||
|
||||
class TestAfterModel:
|
||||
def test_returns_none_when_agent_still_using_tools(self):
|
||||
mw = TodoMiddleware()
|
||||
@@ -235,68 +348,299 @@ class TestAfterModel:
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_injects_reminder_and_jumps_to_model_when_incomplete(self):
|
||||
def test_queues_reminder_and_jumps_to_model_when_incomplete(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [HumanMessage(content="hi"), _ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = mw.after_model(state, _make_runtime())
|
||||
result = mw.after_model(state, runtime)
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
assert len(result["messages"]) == 1
|
||||
reminder = result["messages"][0]
|
||||
assert "messages" not in result
|
||||
|
||||
request = MagicMock()
|
||||
request.runtime = runtime
|
||||
request.messages = state["messages"]
|
||||
request.override.return_value = "patched-request"
|
||||
handler = MagicMock(return_value="response")
|
||||
|
||||
assert mw.wrap_model_call(request, handler) == "response"
|
||||
request.override.assert_called_once()
|
||||
reminder = request.override.call_args.kwargs["messages"][-1]
|
||||
assert isinstance(reminder, HumanMessage)
|
||||
assert reminder.name == "todo_completion_reminder"
|
||||
assert reminder.additional_kwargs["hide_from_ui"] is True
|
||||
assert "Step 2" in reminder.content
|
||||
assert "Step 3" in reminder.content
|
||||
handler.assert_called_once_with("patched-request")
|
||||
|
||||
def test_reminder_lists_only_incomplete_items(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = mw.after_model(state, _make_runtime())
|
||||
content = result["messages"][0].content
|
||||
result = mw.after_model(state, runtime)
|
||||
assert result is not None
|
||||
|
||||
request = MagicMock()
|
||||
request.runtime = runtime
|
||||
request.messages = state["messages"]
|
||||
request.override.return_value = "patched-request"
|
||||
mw.wrap_model_call(request, MagicMock(return_value="response"))
|
||||
content = request.override.call_args.kwargs["messages"][-1].content
|
||||
assert "Step 1" not in content # completed — should not appear
|
||||
assert "Step 2" in content
|
||||
assert "Step 3" in content
|
||||
|
||||
def test_allows_exit_after_max_reminders(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [
|
||||
_completion_reminder_msg(),
|
||||
_completion_reminder_msg(),
|
||||
_ai_no_tool_calls(),
|
||||
],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, runtime) is not None
|
||||
assert mw.after_model(state, runtime) is not None
|
||||
assert mw.after_model(state, runtime) is None
|
||||
|
||||
def test_still_sends_reminder_before_cap(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [
|
||||
_ai_no_tool_calls(),
|
||||
],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, runtime) is not None
|
||||
result = mw.after_model(state, runtime)
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
|
||||
def test_does_not_trigger_for_invalid_tool_calls(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_with_invalid_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_still_sends_reminder_before_cap(self):
|
||||
def test_does_not_trigger_for_raw_provider_tool_calls(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [
|
||||
_completion_reminder_msg(), # 1 reminder so far
|
||||
_ai_no_tool_calls(),
|
||||
],
|
||||
"messages": [_ai_with_raw_provider_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = mw.after_model(state, _make_runtime())
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_does_not_trigger_for_legacy_function_call(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_with_legacy_function_call()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
def test_does_not_trigger_for_tool_finish_reason(self):
|
||||
mw = TodoMiddleware()
|
||||
state = {
|
||||
"messages": [_ai_with_tool_finish_reason()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
assert mw.after_model(state, _make_runtime()) is None
|
||||
|
||||
|
||||
class TestAafterModel:
|
||||
def test_delegates_to_sync(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
result = asyncio.run(mw.aafter_model(state, _make_runtime()))
|
||||
result = asyncio.run(mw.aafter_model(state, runtime))
|
||||
assert result is not None
|
||||
assert result["jump_to"] == "model"
|
||||
assert result["messages"][0].name == "todo_completion_reminder"
|
||||
assert "messages" not in result
|
||||
|
||||
|
||||
class TestWrapModelCall:
|
||||
def test_no_pending_reminder_passthrough(self):
|
||||
mw = TodoMiddleware()
|
||||
request = MagicMock()
|
||||
request.runtime = _make_runtime()
|
||||
request.messages = [HumanMessage(content="hi")]
|
||||
handler = MagicMock(return_value="response")
|
||||
|
||||
assert mw.wrap_model_call(request, handler) == "response"
|
||||
request.override.assert_not_called()
|
||||
handler.assert_called_once_with(request)
|
||||
|
||||
def test_pending_reminder_is_injected_once(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
mw.after_model(state, runtime)
|
||||
|
||||
request = MagicMock()
|
||||
request.runtime = runtime
|
||||
request.messages = state["messages"]
|
||||
request.override.return_value = "patched-request"
|
||||
handler = MagicMock(return_value="response")
|
||||
|
||||
assert mw.wrap_model_call(request, handler) == "response"
|
||||
injected_messages = request.override.call_args.kwargs["messages"]
|
||||
assert injected_messages[-1].name == "todo_completion_reminder"
|
||||
|
||||
request.override.reset_mock()
|
||||
handler.reset_mock()
|
||||
handler.return_value = "second-response"
|
||||
assert mw.wrap_model_call(request, handler) == "second-response"
|
||||
request.override.assert_not_called()
|
||||
handler.assert_called_once_with(request)
|
||||
|
||||
|
||||
class TestTodoMiddlewareAgentGraphIntegration:
|
||||
def test_completion_reminder_is_transient_in_real_agent_graph(self):
|
||||
mw = TodoMiddleware()
|
||||
model = _CapturingFakeMessagesListChatModel(
|
||||
responses=[
|
||||
AIMessage(
|
||||
content="",
|
||||
tool_calls=[
|
||||
{
|
||||
"name": "write_todos",
|
||||
"id": "todos-1",
|
||||
"args": {
|
||||
"todos": [
|
||||
{"content": "Step 1", "status": "completed"},
|
||||
{"content": "Step 2", "status": "pending"},
|
||||
]
|
||||
},
|
||||
}
|
||||
],
|
||||
),
|
||||
AIMessage(content="premature final 1"),
|
||||
AIMessage(content="premature final 2"),
|
||||
AIMessage(content="premature final 3"),
|
||||
],
|
||||
)
|
||||
graph = create_agent(model=model, tools=[], middleware=[mw])
|
||||
|
||||
result = graph.invoke(
|
||||
{"messages": [("user", "finish all todos")]},
|
||||
context={"thread_id": "integration-thread", "run_id": "integration-run"},
|
||||
)
|
||||
|
||||
assert len(model.seen_messages) == 4
|
||||
reminders_by_call = [_todo_completion_reminders(messages) for messages in model.seen_messages]
|
||||
assert reminders_by_call[0] == []
|
||||
assert reminders_by_call[1] == []
|
||||
assert len(reminders_by_call[2]) == 1
|
||||
assert len(reminders_by_call[3]) == 1
|
||||
assert "Step 1" not in reminders_by_call[2][0].content
|
||||
assert "Step 2" in reminders_by_call[2][0].content
|
||||
|
||||
persisted_reminders = _todo_completion_reminders(result["messages"])
|
||||
assert persisted_reminders == []
|
||||
assert result["messages"][-1].content == "premature final 3"
|
||||
assert result["todos"] == [
|
||||
{"content": "Step 1", "status": "completed"},
|
||||
{"content": "Step 2", "status": "pending"},
|
||||
]
|
||||
assert mw._pending_completion_reminders == {}
|
||||
assert mw._completion_reminder_counts == {}
|
||||
|
||||
|
||||
class TestRunScopedReminderCleanup:
|
||||
def test_before_agent_clears_stale_count_without_pending_reminder(self):
|
||||
mw = TodoMiddleware()
|
||||
stale_runtime = _make_runtime()
|
||||
stale_runtime.context = {"thread_id": "test-thread", "run_id": "stale-run"}
|
||||
current_runtime = _make_runtime()
|
||||
current_runtime.context = {"thread_id": "test-thread", "run_id": "current-run"}
|
||||
other_thread_runtime = _make_runtime()
|
||||
other_thread_runtime.context = {"thread_id": "other-thread", "run_id": "stale-run"}
|
||||
|
||||
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
||||
assert mw.after_model(state, stale_runtime) is not None
|
||||
assert mw.after_model(state, other_thread_runtime) is not None
|
||||
|
||||
# Simulate a model call that drained the pending message, followed by an
|
||||
# abnormal run end where after_agent did not clear the reminder count.
|
||||
assert mw._drain_completion_reminders(stale_runtime)
|
||||
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 1
|
||||
|
||||
mw.before_agent({}, current_runtime)
|
||||
|
||||
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
|
||||
assert mw._completion_reminder_count_for_runtime(other_thread_runtime) == 1
|
||||
|
||||
def test_size_guard_prunes_oldest_count_only_reminder_state(self):
|
||||
mw = TodoMiddleware()
|
||||
mw._MAX_COMPLETION_REMINDER_KEYS = 2
|
||||
first_runtime = _make_runtime_for("thread-a", "run-a")
|
||||
second_runtime = _make_runtime_for("thread-b", "run-b")
|
||||
third_runtime = _make_runtime_for("thread-c", "run-c")
|
||||
|
||||
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
||||
assert mw.after_model(state, first_runtime) is not None
|
||||
|
||||
# Simulate the normal model request path: pending reminder is consumed,
|
||||
# but the run count remains until after_agent() or stale cleanup.
|
||||
assert mw._drain_completion_reminders(first_runtime)
|
||||
assert mw._completion_reminder_count_for_runtime(first_runtime) == 1
|
||||
|
||||
assert mw.after_model(state, second_runtime) is not None
|
||||
assert mw.after_model(state, third_runtime) is not None
|
||||
|
||||
assert mw._completion_reminder_count_for_runtime(first_runtime) == 0
|
||||
assert mw._completion_reminder_count_for_runtime(second_runtime) == 1
|
||||
assert mw._completion_reminder_count_for_runtime(third_runtime) == 1
|
||||
assert ("thread-a", "run-a") not in mw._completion_reminder_touch_order
|
||||
|
||||
def test_size_guard_prunes_pending_and_count_state_together(self):
|
||||
mw = TodoMiddleware()
|
||||
mw._MAX_COMPLETION_REMINDER_KEYS = 1
|
||||
stale_runtime = _make_runtime_for("thread-a", "run-a")
|
||||
current_runtime = _make_runtime_for("thread-b", "run-b")
|
||||
|
||||
state = {"messages": [_ai_no_tool_calls()], "todos": _incomplete_todos()}
|
||||
assert mw.after_model(state, stale_runtime) is not None
|
||||
assert mw.after_model(state, current_runtime) is not None
|
||||
|
||||
assert mw._drain_completion_reminders(stale_runtime) == []
|
||||
assert mw._completion_reminder_count_for_runtime(stale_runtime) == 0
|
||||
assert mw._completion_reminder_count_for_runtime(current_runtime) == 1
|
||||
|
||||
|
||||
class TestAwrapModelCall:
|
||||
def test_async_pending_reminder_is_injected(self):
|
||||
mw = TodoMiddleware()
|
||||
runtime = _make_runtime()
|
||||
state = {
|
||||
"messages": [_ai_no_tool_calls()],
|
||||
"todos": _incomplete_todos(),
|
||||
}
|
||||
mw.after_model(state, runtime)
|
||||
|
||||
request = MagicMock()
|
||||
request.runtime = runtime
|
||||
request.messages = state["messages"]
|
||||
request.override.return_value = "patched-request"
|
||||
handler = AsyncMock(return_value="response")
|
||||
|
||||
result = asyncio.run(mw.awrap_model_call(request, handler))
|
||||
assert result == "response"
|
||||
injected_messages = request.override.call_args.kwargs["messages"]
|
||||
assert injected_messages[-1].name == "todo_completion_reminder"
|
||||
handler.assert_awaited_once_with("patched-request")
|
||||
|
||||
Generated
+19
-2
@@ -1,5 +1,5 @@
|
||||
version = 1
|
||||
revision = 3
|
||||
revision = 2
|
||||
requires-python = ">=3.12"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.14' and sys_platform == 'win32'",
|
||||
@@ -763,6 +763,9 @@ dependencies = [
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
discord = [
|
||||
{ name = "discord-py" },
|
||||
]
|
||||
postgres = [
|
||||
{ name = "deerflow-harness", extra = ["postgres"] },
|
||||
]
|
||||
@@ -781,6 +784,7 @@ requires-dist = [
|
||||
{ name = "deerflow-harness", editable = "packages/harness" },
|
||||
{ name = "deerflow-harness", extras = ["postgres"], marker = "extra == 'postgres'", editable = "packages/harness" },
|
||||
{ name = "dingtalk-stream", specifier = ">=0.24.3" },
|
||||
{ name = "discord-py", marker = "extra == 'discord'", specifier = ">=2.7.0" },
|
||||
{ name = "email-validator", specifier = ">=2.0.0" },
|
||||
{ name = "fastapi", specifier = ">=0.115.0" },
|
||||
{ name = "httpx", specifier = ">=0.28.0" },
|
||||
@@ -795,7 +799,7 @@ requires-dist = [
|
||||
{ name = "uvicorn", extras = ["standard"], specifier = ">=0.34.0" },
|
||||
{ name = "wecom-aibot-python-sdk", specifier = ">=0.1.6" },
|
||||
]
|
||||
provides-extras = ["postgres"]
|
||||
provides-extras = ["postgres", "discord"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
@@ -923,6 +927,19 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/4c/44/102dede3f371277598df6aa9725b82e3add068c729333c7a5dbc12764579/dingtalk_stream-0.24.3-py3-none-any.whl", hash = "sha256:2160403656985962878bf60cdf5adf41619f21067348e06f07a7c7eebf5943ad", size = 27813, upload-time = "2025-10-24T09:36:57.497Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "discord-py"
|
||||
version = "2.7.1"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "aiohttp" },
|
||||
{ name = "audioop-lts", marker = "python_full_version >= '3.13'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ef/57/9a2d9abdabdc9db8ef28ce0cf4129669e1c8717ba28d607b5ba357c4de3b/discord_py-2.7.1.tar.gz", hash = "sha256:24d5e6a45535152e4b98148a9dd6b550d25dc2c9fb41b6d670319411641249da", size = 1106326, upload-time = "2026-03-03T18:40:46.24Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/f7/a7/17208c3b3f92319e7fad259f1c6d5a5baf8fd0654c54846ced329f83c3eb/discord_py-2.7.1-py3-none-any.whl", hash = "sha256:849dca2c63b171146f3a7f3f8acc04248098e9e6203412ce3cf2745f284f7439", size = 1227550, upload-time = "2026-03-03T18:40:44.492Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "distro"
|
||||
version = "1.9.0"
|
||||
|
||||
@@ -1029,6 +1029,14 @@ run_events:
|
||||
# client_secret: $DINGTALK_CLIENT_SECRET
|
||||
# allowed_users: [] # empty = allow all
|
||||
# card_template_id: "" # Optional: AI Card template ID for streaming updates
|
||||
#
|
||||
# discord:
|
||||
# enabled: false
|
||||
# bot_token: $DISCORD_BOT_TOKEN
|
||||
# allowed_guilds: [] # empty = allow all guilds; can also be a single guild ID
|
||||
# mention_only: false # If true, only respond when the bot is mentioned
|
||||
# allowed_channels: [] # Optional: channel IDs exempt from mention_only (bot responds without mention)
|
||||
# thread_mode: false # If true, group a channel conversation into a thread
|
||||
|
||||
# ============================================================================
|
||||
# Guardrails Configuration
|
||||
|
||||
+21
-3
@@ -28,6 +28,10 @@ http {
|
||||
set $gateway_upstream gateway:8001;
|
||||
set $frontend_upstream frontend:3000;
|
||||
|
||||
# Default proxy settings for all locations (streaming/SSE support)
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
|
||||
# Keep the unified nginx endpoint same-origin by default. When split
|
||||
# frontend/backend or port-forwarded deployments need browser CORS,
|
||||
# configure the Gateway allowlist with GATEWAY_CORS_ORIGINS so CORS and
|
||||
@@ -49,8 +53,6 @@ http {
|
||||
proxy_set_header Connection '';
|
||||
|
||||
# SSE/Streaming support
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
proxy_set_header X-Accel-Buffering no;
|
||||
|
||||
# Timeouts for long-running requests
|
||||
@@ -70,6 +72,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Custom API: Memory endpoint
|
||||
@@ -80,6 +83,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Custom API: MCP configuration endpoint
|
||||
@@ -90,6 +94,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Custom API: Skills configuration endpoint
|
||||
@@ -100,6 +105,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Custom API: Agents endpoint
|
||||
@@ -110,6 +116,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Custom API: Uploads endpoint
|
||||
@@ -124,6 +131,8 @@ http {
|
||||
# Large file upload support
|
||||
client_max_body_size 100M;
|
||||
proxy_request_buffering off;
|
||||
|
||||
# Disable response buffering to avoid permission errors
|
||||
}
|
||||
|
||||
# Custom API: Other endpoints under /api/threads
|
||||
@@ -134,6 +143,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# API Documentation: Swagger UI
|
||||
@@ -144,6 +154,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# API Documentation: ReDoc
|
||||
@@ -154,6 +165,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# API Documentation: OpenAPI Schema
|
||||
@@ -164,6 +176,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Health check endpoint (gateway)
|
||||
@@ -174,6 +187,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# ── Provisioner API (sandbox management) ────────────────────────
|
||||
@@ -187,6 +201,7 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
}
|
||||
|
||||
# Catch-all for /api/ routes not covered above (e.g. /api/v1/auth/*).
|
||||
@@ -198,6 +213,9 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
# Disable buffering to avoid permission errors when nginx
|
||||
# runs as a non-root user (e.g. local development).
|
||||
}
|
||||
|
||||
# All other requests go to frontend
|
||||
@@ -220,4 +238,4 @@ http {
|
||||
proxy_read_timeout 600s;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -70,6 +70,11 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
# Disable buffering to avoid permission errors when nginx
|
||||
# runs as a non-root user (e.g. local development).
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: Memory endpoint
|
||||
@@ -80,6 +85,9 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: MCP configuration endpoint
|
||||
@@ -90,6 +98,9 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: Skills configuration endpoint
|
||||
@@ -100,6 +111,9 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: Agents endpoint
|
||||
@@ -110,6 +124,9 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: Uploads endpoint
|
||||
@@ -124,6 +141,10 @@ http {
|
||||
# Large file upload support
|
||||
client_max_body_size 100M;
|
||||
proxy_request_buffering off;
|
||||
|
||||
# Disable response buffering to avoid permission errors
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Custom API: Other endpoints under /api/threads
|
||||
@@ -134,6 +155,9 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# API Documentation: Swagger UI
|
||||
@@ -144,6 +168,9 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# API Documentation: ReDoc
|
||||
@@ -154,6 +181,9 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# API Documentation: OpenAPI Schema
|
||||
@@ -164,6 +194,9 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Health check endpoint (gateway)
|
||||
@@ -174,6 +207,9 @@ http {
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
|
||||
proxy_set_header X-Forwarded-Proto $scheme;
|
||||
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# Catch-all for any /api/* prefix not matched by a more specific block above.
|
||||
@@ -193,6 +229,11 @@ http {
|
||||
# Auth endpoints set HttpOnly cookies — make sure nginx doesn't
|
||||
# strip the Set-Cookie header from upstream responses.
|
||||
proxy_pass_header Set-Cookie;
|
||||
|
||||
# Disable buffering to avoid permission errors when nginx
|
||||
# runs as a non-root user (e.g. local development).
|
||||
proxy_buffering off;
|
||||
proxy_cache off;
|
||||
}
|
||||
|
||||
# All other requests go to frontend
|
||||
|
||||
@@ -66,6 +66,7 @@ export default function AgentChatPage() {
|
||||
thread,
|
||||
pendingUsageMessages,
|
||||
sendMessage,
|
||||
isUploading,
|
||||
isHistoryLoading,
|
||||
hasMoreHistory,
|
||||
loadMoreHistory,
|
||||
@@ -106,7 +107,11 @@ export default function AgentChatPage() {
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
(message: PromptInputMessage) => {
|
||||
void sendMessage(threadId, message, { agent_name });
|
||||
const sendPromise = sendMessage(threadId, message, { agent_name });
|
||||
if (message.files.length > 0) {
|
||||
return sendPromise;
|
||||
}
|
||||
void sendPromise;
|
||||
},
|
||||
[sendMessage, threadId, agent_name],
|
||||
);
|
||||
@@ -243,7 +248,10 @@ export default function AgentChatPage() {
|
||||
<AgentWelcome agent={agent} agentName={agent_name} />
|
||||
)
|
||||
}
|
||||
disabled={env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true"}
|
||||
disabled={
|
||||
env.NEXT_PUBLIC_STATIC_WEBSITE_ONLY === "true" ||
|
||||
isUploading
|
||||
}
|
||||
onContextChange={(context) => setSettings("context", context)}
|
||||
onSubmit={handleSubmit}
|
||||
onStop={handleStop}
|
||||
|
||||
@@ -109,7 +109,11 @@ export default function ChatPage() {
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
(message: PromptInputMessage) => {
|
||||
void sendMessage(threadId, message);
|
||||
const sendPromise = sendMessage(threadId, message);
|
||||
if (message.files.length > 0) {
|
||||
return sendPromise;
|
||||
}
|
||||
void sendPromise;
|
||||
},
|
||||
[sendMessage, threadId],
|
||||
);
|
||||
|
||||
@@ -499,6 +499,10 @@ export const PromptInput = ({
|
||||
// Keep a ref to files for cleanup on unmount (avoids stale closure)
|
||||
const filesRef = useRef(files);
|
||||
filesRef.current = files;
|
||||
const providerTextRef = useRef("");
|
||||
if (usingProvider) {
|
||||
providerTextRef.current = controller.textInput.value;
|
||||
}
|
||||
|
||||
const openFileDialogLocal = useCallback(() => {
|
||||
inputRef.current?.click();
|
||||
@@ -768,6 +772,24 @@ export const PromptInput = ({
|
||||
}
|
||||
|
||||
// Convert blob URLs to data URLs asynchronously
|
||||
const submittedFileIds = files.map((file) => file.id);
|
||||
const clearSubmittedState = () => {
|
||||
const currentFileIds = new Set(filesRef.current.map((file) => file.id));
|
||||
const submittedFileIdsStillPresent = submittedFileIds.filter((id) =>
|
||||
currentFileIds.has(id),
|
||||
);
|
||||
if (submittedFileIdsStillPresent.length === filesRef.current.length) {
|
||||
clear();
|
||||
} else {
|
||||
for (const id of submittedFileIdsStillPresent) {
|
||||
remove(id);
|
||||
}
|
||||
}
|
||||
if (usingProvider && providerTextRef.current === text) {
|
||||
controller.textInput.clear();
|
||||
}
|
||||
};
|
||||
|
||||
Promise.all(
|
||||
files.map(async ({ id, ...item }) => {
|
||||
if (item.file instanceof File) {
|
||||
@@ -793,20 +815,14 @@ export const PromptInput = ({
|
||||
if (result instanceof Promise) {
|
||||
result
|
||||
.then(() => {
|
||||
clear();
|
||||
if (usingProvider) {
|
||||
controller.textInput.clear();
|
||||
}
|
||||
clearSubmittedState();
|
||||
})
|
||||
.catch(() => {
|
||||
// Don't clear on error - user may want to retry
|
||||
});
|
||||
} else {
|
||||
// Sync function completed without throwing, clear attachments
|
||||
clear();
|
||||
if (usingProvider) {
|
||||
controller.textInput.clear();
|
||||
}
|
||||
clearSubmittedState();
|
||||
}
|
||||
} catch {
|
||||
// Don't clear on error - user may want to retry
|
||||
|
||||
@@ -110,6 +110,7 @@ export function InputBox({
|
||||
threadId,
|
||||
initialValue,
|
||||
onContextChange,
|
||||
onFollowupsVisibilityChange,
|
||||
onSubmit,
|
||||
onStop,
|
||||
...props
|
||||
@@ -142,7 +143,8 @@ export function InputBox({
|
||||
reasoning_effort?: "minimal" | "low" | "medium" | "high";
|
||||
},
|
||||
) => void;
|
||||
onSubmit?: (message: PromptInputMessage) => void;
|
||||
onFollowupsVisibilityChange?: (visible: boolean) => void;
|
||||
onSubmit?: (message: PromptInputMessage) => void | Promise<void>;
|
||||
onStop?: () => void;
|
||||
}) {
|
||||
const { t } = useI18n();
|
||||
@@ -251,12 +253,12 @@ export function InputBox({
|
||||
);
|
||||
|
||||
const handleSubmit = useCallback(
|
||||
async (message: PromptInputMessage) => {
|
||||
(message: PromptInputMessage) => {
|
||||
if (status === "streaming") {
|
||||
onStop?.();
|
||||
return;
|
||||
}
|
||||
if (!message.text) {
|
||||
if (!message.text.trim() && message.files.length === 0) {
|
||||
return;
|
||||
}
|
||||
setFollowups([]);
|
||||
@@ -274,11 +276,14 @@ export function InputBox({
|
||||
selectedModel?.supports_thinking ?? false,
|
||||
),
|
||||
});
|
||||
setTimeout(() => onSubmit?.(message), 0);
|
||||
return;
|
||||
return new Promise<void>((resolve, reject) => {
|
||||
setTimeout(() => {
|
||||
Promise.resolve(onSubmit?.(message)).then(resolve).catch(reject);
|
||||
}, 0);
|
||||
});
|
||||
}
|
||||
|
||||
onSubmit?.(message);
|
||||
return onSubmit?.(message);
|
||||
},
|
||||
[
|
||||
context,
|
||||
@@ -348,6 +353,14 @@ export function InputBox({
|
||||
!followupsHidden &&
|
||||
(followupsLoading || followups.length > 0);
|
||||
|
||||
useEffect(() => {
|
||||
onFollowupsVisibilityChange?.(showFollowups);
|
||||
}, [onFollowupsVisibilityChange, showFollowups]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => onFollowupsVisibilityChange?.(false);
|
||||
}, [onFollowupsVisibilityChange]);
|
||||
|
||||
useEffect(() => {
|
||||
messagesRef.current = thread.messages;
|
||||
}, [thread.messages]);
|
||||
|
||||
@@ -26,6 +26,13 @@ export type MessageGroup =
|
||||
| AssistantClarificationGroup
|
||||
| AssistantSubagentGroup;
|
||||
|
||||
// Message `name` values that identify internal control messages (summary,
// loop warning, todo reminders). Kept in a Set so membership can be tested
// with a single `.has()` lookup by message name.
const HIDDEN_CONTROL_MESSAGE_NAMES = new Set([
  "summary",
  "loop_warning",
  "todo_reminder",
  "todo_completion_reminder",
]);
|
||||
|
||||
export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
||||
if (messages.length === 0) {
|
||||
return [];
|
||||
@@ -53,10 +60,6 @@ export function getMessageGroups(messages: Message[]): MessageGroup[] {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (message.name === "todo_reminder") {
|
||||
continue;
|
||||
}
|
||||
|
||||
if (message.type === "human") {
|
||||
groups.push({ id: message.id, type: "human", messages: [message] });
|
||||
continue;
|
||||
@@ -368,8 +371,8 @@ export function findToolCallResult(toolCallId: string, messages: Message[]) {
|
||||
export function isHiddenFromUIMessage(message: Message) {
|
||||
return (
|
||||
message.additional_kwargs?.hide_from_ui === true ||
|
||||
message.name === "summary" ||
|
||||
message.name === "loop_warning"
|
||||
(typeof message.name === "string" &&
|
||||
HIDDEN_CONTROL_MESSAGE_NAMES.has(message.name))
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -45,15 +45,60 @@ type SendMessageOptions = {
|
||||
additionalKwargs?: Record<string, unknown>;
|
||||
};
|
||||
|
||||
function mergeMessages(
|
||||
/**
 * Type guard that accepts only strings containing at least one character.
 *
 * @param value - Candidate value; may be `undefined`.
 * @returns `true` when `value` is a string of length > 0, narrowing it to `string`.
 */
function isNonEmptyString(value: string | undefined): value is string {
  if (typeof value !== "string") {
    return false;
  }
  return value.length !== 0;
}
|
||||
|
||||
/**
 * Derive a namespaced identity key for a message.
 *
 * A non-empty `tool_call_id` takes precedence and yields `tool:<id>`;
 * otherwise a non-empty `id` yields `message:<id>`. The prefixes keep the
 * two id spaces from colliding with each other.
 *
 * @param message - Message to identify.
 * @returns The identity key, or `undefined` when neither id is usable.
 */
function messageIdentity(message: Message): string | undefined {
  const toolCallId =
    "tool_call_id" in message ? message.tool_call_id : undefined;
  if (typeof toolCallId === "string" && toolCallId.length > 0) {
    return `tool:${toolCallId}`;
  }

  const { id } = message;
  if (typeof id === "string" && id.length > 0) {
    return `message:${id}`;
  }

  return undefined;
}
|
||||
|
||||
/**
 * Drop earlier duplicates of messages that share an identity key.
 *
 * For every identity produced by `messageIdentity`, only the LAST occurrence
 * in `messages` is kept; messages without an identity are always kept.
 * Relative order of the surviving messages is unchanged.
 *
 * @param messages - Messages in display order, possibly containing duplicates.
 * @returns A new array with earlier duplicates removed.
 */
function dedupeMessagesByIdentity(messages: Message[]): Message[] {
  // Compute each identity once; the same list drives both passes below.
  const identities = messages.map((message) => messageIdentity(message));

  // Pass 1: remember the final position at which each identity appears.
  const finalPositionOf = new Map<string, number>();
  identities.forEach((identity, position) => {
    if (identity) {
      finalPositionOf.set(identity, position);
    }
  });

  // Pass 2: keep unidentified messages and the final occurrence of each identity.
  return messages.filter((_message, position) => {
    const identity = identities[position];
    if (!identity) {
      return true;
    }
    return finalPositionOf.get(identity) === position;
  });
}
|
||||
|
||||
/**
 * Find the index of the last run whose `run_id` is not in `loadedRunIds`.
 *
 * The array is scanned from its end towards its start, so the highest
 * matching index wins.
 *
 * @param runs - Runs to search.
 * @param loadedRunIds - Ids of runs that have already been loaded.
 * @returns The matching index, or `-1` when every run is loaded (or `runs` is empty).
 */
function findLatestUnloadedRunIndex(
  runs: Run[],
  loadedRunIds: ReadonlySet<string>,
): number {
  let cursor = runs.length - 1;
  while (cursor >= 0) {
    const candidate = runs[cursor];
    if (candidate && !loadedRunIds.has(candidate.run_id)) {
      return cursor;
    }
    cursor -= 1;
  }
  return -1;
}
|
||||
|
||||
export function mergeMessages(
|
||||
historyMessages: Message[],
|
||||
threadMessages: Message[],
|
||||
optimisticMessages: Message[],
|
||||
): Message[] {
|
||||
const threadMessageIds = new Set(
|
||||
threadMessages
|
||||
.map((m) => ("tool_call_id" in m ? m.tool_call_id : m.id))
|
||||
.filter(Boolean),
|
||||
threadMessages.map(messageIdentity).filter(isNonEmptyString),
|
||||
);
|
||||
|
||||
// The overlap is a contiguous suffix of historyMessages (newest history == oldest thread).
|
||||
@@ -65,28 +110,19 @@ function mergeMessages(
|
||||
if (!msg) {
|
||||
continue;
|
||||
}
|
||||
if (
|
||||
(msg?.id && threadMessageIds.has(msg.id)) ||
|
||||
("tool_call_id" in msg && threadMessageIds.has(msg.tool_call_id))
|
||||
) {
|
||||
const identity = messageIdentity(msg);
|
||||
if (identity && threadMessageIds.has(identity)) {
|
||||
cutoff = i;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
return [
|
||||
return dedupeMessagesByIdentity([
|
||||
...historyMessages.slice(0, cutoff),
|
||||
...threadMessages,
|
||||
...optimisticMessages,
|
||||
];
|
||||
}
|
||||
|
||||
function messageIdentity(message: Message): string | undefined {
|
||||
if ("tool_call_id" in message) {
|
||||
return message.tool_call_id;
|
||||
}
|
||||
return message.id;
|
||||
]);
|
||||
}
|
||||
|
||||
function getMessagesAfterBaseline(
|
||||
@@ -627,48 +663,105 @@ export function useThreadHistory(threadId: string) {
|
||||
const runsRef = useRef(runs.data ?? []);
|
||||
const indexRef = useRef(-1);
|
||||
const loadingRef = useRef(false);
|
||||
const pendingLoadRef = useRef(false);
|
||||
const loadingRunIdRef = useRef<string | null>(null);
|
||||
const loadedRunIdsRef = useRef<Set<string>>(new Set());
|
||||
const [loading, setLoading] = useState(false);
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
|
||||
loadingRef.current = loading;
|
||||
const loadMessages = useCallback(async () => {
|
||||
if (loadingRef.current) {
|
||||
const pendingRunIndex = findLatestUnloadedRunIndex(
|
||||
runsRef.current,
|
||||
loadedRunIdsRef.current,
|
||||
);
|
||||
const pendingRun = runsRef.current[pendingRunIndex];
|
||||
if (pendingRun && pendingRun.run_id !== loadingRunIdRef.current) {
|
||||
pendingLoadRef.current = true;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (runsRef.current.length === 0) {
|
||||
return;
|
||||
}
|
||||
const run = runsRef.current[indexRef.current];
|
||||
if (!run || loadingRef.current) {
|
||||
return;
|
||||
}
|
||||
|
||||
loadingRef.current = true;
|
||||
setLoading(true);
|
||||
|
||||
try {
|
||||
setLoading(true);
|
||||
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
|
||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(threadIdRef.current)}/runs/${encodeURIComponent(run.run_id)}/messages`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
do {
|
||||
pendingLoadRef.current = false;
|
||||
|
||||
const nextRunIndex = findLatestUnloadedRunIndex(
|
||||
runsRef.current,
|
||||
loadedRunIdsRef.current,
|
||||
);
|
||||
indexRef.current = nextRunIndex;
|
||||
|
||||
const run = runsRef.current[nextRunIndex];
|
||||
if (!run) {
|
||||
indexRef.current = -1;
|
||||
return;
|
||||
}
|
||||
|
||||
const requestThreadId = threadIdRef.current;
|
||||
loadingRunIdRef.current = run.run_id;
|
||||
const result: { data: RunMessage[]; hasMore: boolean } = await fetch(
|
||||
`${getBackendBaseURL()}/api/threads/${encodeURIComponent(requestThreadId)}/runs/${encodeURIComponent(run.run_id)}/messages`,
|
||||
{
|
||||
method: "GET",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
credentials: "include",
|
||||
},
|
||||
credentials: "include",
|
||||
},
|
||||
).then((res) => {
|
||||
return res.json();
|
||||
});
|
||||
const _messages = result.data
|
||||
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
||||
.map((m) => m.content);
|
||||
setMessages((prev) => [..._messages, ...prev]);
|
||||
indexRef.current -= 1;
|
||||
).then((res) => {
|
||||
return res.json();
|
||||
});
|
||||
const _messages = result.data
|
||||
.filter((m) => !m.metadata.caller?.startsWith("middleware:"))
|
||||
.map((m) => m.content);
|
||||
if (threadIdRef.current !== requestThreadId) {
|
||||
return;
|
||||
}
|
||||
setMessages((prev) =>
|
||||
dedupeMessagesByIdentity([..._messages, ...prev]),
|
||||
);
|
||||
loadedRunIdsRef.current.add(run.run_id);
|
||||
indexRef.current = findLatestUnloadedRunIndex(
|
||||
runsRef.current,
|
||||
loadedRunIdsRef.current,
|
||||
);
|
||||
} while (pendingLoadRef.current);
|
||||
} catch (err) {
|
||||
console.error(err);
|
||||
} finally {
|
||||
loadingRef.current = false;
|
||||
loadingRunIdRef.current = null;
|
||||
setLoading(false);
|
||||
}
|
||||
}, []);
|
||||
useEffect(() => {
|
||||
const threadChanged = threadIdRef.current !== threadId;
|
||||
threadIdRef.current = threadId;
|
||||
|
||||
if (threadChanged) {
|
||||
runsRef.current = [];
|
||||
indexRef.current = -1;
|
||||
pendingLoadRef.current = false;
|
||||
loadingRunIdRef.current = null;
|
||||
loadedRunIdsRef.current = new Set();
|
||||
loadingRef.current = false;
|
||||
setLoading(false);
|
||||
setMessages([]);
|
||||
}
|
||||
|
||||
if (runs.data && runs.data.length > 0) {
|
||||
runsRef.current = runs.data ?? [];
|
||||
indexRef.current = runs.data.length - 1;
|
||||
indexRef.current = findLatestUnloadedRunIndex(
|
||||
runs.data,
|
||||
loadedRunIdsRef.current,
|
||||
);
|
||||
}
|
||||
loadMessages().catch(() => {
|
||||
toast.error("Failed to load thread history.");
|
||||
@@ -677,7 +770,7 @@ export function useThreadHistory(threadId: string) {
|
||||
|
||||
const appendMessages = useCallback((_messages: Message[]) => {
|
||||
setMessages((prev) => {
|
||||
return [...prev, ..._messages];
|
||||
return dedupeMessagesByIdentity([...prev, ..._messages]);
|
||||
});
|
||||
}, []);
|
||||
const hasMore = indexRef.current >= 0 || !runs.data;
|
||||
|
||||
@@ -48,4 +48,66 @@ test.describe("Chat workspace", () => {
|
||||
timeout: 10_000,
|
||||
});
|
||||
});
|
||||
|
||||
// Checks that an attached file stays listed in the prompt form while its
// upload request is still pending, and is hidden once the submit has finished.
test("keeps attachments visible while upload submit is pending", async ({
  page,
}) => {
  // Deferred handles: `releaseUpload` lets the test decide when the mocked
  // upload may complete; `uploadStarted` signals that the request was made.
  let releaseUpload!: () => void;
  const uploadCanFinish = new Promise<void>((resolve) => {
    releaseUpload = resolve;
  });
  let uploadStarted!: () => void;
  const uploadStartedPromise = new Promise<void>((resolve) => {
    uploadStarted = resolve;
  });

  // Intercept the uploads endpoint: report that it was hit, then hold the
  // response until the test releases it.
  await page.route("**/api/threads/*/uploads", async (route) => {
    uploadStarted();
    await uploadCanFinish;
    return route.fulfill({
      status: 200,
      contentType: "application/json",
      body: JSON.stringify({
        success: true,
        message: "Uploaded",
        files: [
          {
            filename: "report.docx",
            size: 12,
            path: "report.docx",
            virtual_path: "/mnt/user-data/uploads/report.docx",
            artifact_url: "/api/threads/test/uploads/report.docx",
            extension: ".docx",
          },
        ],
      }),
    });
  });

  await page.goto("/workspace/chats/new");

  const textarea = page.getByPlaceholder(/how can i assist you/i);
  await expect(textarea).toBeVisible({ timeout: 15_000 });
  // Scope later lookups to the form that contains the prompt textarea.
  const promptForm = page.locator("form").filter({ has: textarea });

  // Attach an in-memory file and confirm it is listed in the form.
  await page.getByLabel("Upload files").setInputFiles({
    name: "report.docx",
    mimeType:
      "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    buffer: Buffer.from("fake docx"),
  });
  await expect(promptForm.getByText("report.docx")).toBeVisible();

  await textarea.fill("Summarize this document");
  await textarea.press("Enter");

  // The upload request is in flight but not yet answered: the attachment
  // must still be visible.
  await uploadStartedPromise;
  await expect(promptForm.getByText("report.docx")).toBeVisible();

  // Let the upload finish; after the reply text appears the attachment
  // should no longer be shown in the form.
  // NOTE(review): "Hello from DeerFlow!" is presumably a mocked reply set up
  // elsewhere in this spec — confirm.
  releaseUpload();
  await expect(page.getByText("Hello from DeerFlow!")).toBeVisible({
    timeout: 10_000,
  });
  await expect(promptForm.getByText("report.docx")).toBeHidden();
});
|
||||
});
|
||||
|
||||
@@ -63,3 +63,37 @@ test("aggregates token usage messages once per assistant turn", () => {
|
||||
),
|
||||
).toEqual([null, null, ["ai-1", "ai-2"], null, ["ai-3"]]);
|
||||
});
|
||||
|
||||
// Internal todo reminder messages must not surface in the grouped output:
// only the ordinary human message and the assistant reply should remain.
test("hides internal todo reminder messages from message groups", () => {
  const visibleHuman = {
    id: "human-1",
    type: "human",
    content: "Audit the middleware",
  } as Message;
  const completionReminder = {
    id: "todo-reminder-1",
    type: "human",
    name: "todo_completion_reminder",
    content: "<system_reminder>finish todos</system_reminder>",
  } as Message;
  const plainReminder = {
    id: "todo-reminder-2",
    type: "human",
    name: "todo_reminder",
    content: "<system_reminder>remember todos</system_reminder>",
  } as Message;
  const assistantReply = {
    id: "ai-1",
    type: "ai",
    content: "Done",
  } as Message;

  const groups = getMessageGroups([
    visibleHuman,
    completionReminder,
    plainReminder,
    assistantReply,
  ]);

  const groupTypes = groups.map((group) => group.type);
  const survivingIds = groups
    .flatMap((group) => group.messages)
    .map((message) => message.id);

  expect(groupTypes).toEqual(["human", "assistant"]);
  expect(survivingIds).toEqual(["human-1", "ai-1"]);
});
|
||||
|
||||
@@ -0,0 +1,64 @@
|
||||
import type { Message } from "@langchain/langgraph-sdk";
|
||||
import { expect, test } from "vitest";
|
||||
|
||||
import { mergeMessages } from "@/core/threads/hooks";
|
||||
|
||||
// A history that repeats the same human/ai pair must collapse to one pair.
test("mergeMessages removes duplicate messages already present in history", () => {
  const userTurn = {
    id: "human-1",
    type: "human",
    content: "Design an agent",
  } as Message;
  const assistantTurn = {
    id: "ai-1",
    type: "ai",
    content: "Let's design it.",
  } as Message;

  const repeatedHistory = [userTurn, assistantTurn, userTurn, assistantTurn];
  const merged = mergeMessages(repeatedHistory, [], []);

  expect(merged).toEqual([userTurn, assistantTurn]);
});
|
||||
|
||||
// When history and the live thread share message ids, the live copies win.
test("mergeMessages lets live thread messages replace overlapping history", () => {
  const staleHistory = [
    { id: "human-1", type: "human", content: "old" } as Message,
    { id: "ai-1", type: "ai", content: "old" } as Message,
  ];
  const liveHuman = {
    id: "human-1",
    type: "human",
    content: "live",
  } as Message;
  const liveAi = {
    id: "ai-1",
    type: "ai",
    content: "live",
  } as Message;

  const merged = mergeMessages(staleHistory, [liveHuman, liveAi], []);

  expect(merged).toEqual([liveHuman, liveAi]);
});
|
||||
|
||||
// Tool messages with different ids but the same tool_call_id count as the
// same message; the live one replaces the one from history.
test("mergeMessages deduplicates tool messages by tool_call_id", () => {
  const historyTool = {
    id: "tool-message-old",
    type: "tool",
    tool_call_id: "call-1",
    content: "old",
  } as Message;
  const threadTool = {
    id: "tool-message-live",
    type: "tool",
    tool_call_id: "call-1",
    content: "live",
  } as Message;

  const merged = mergeMessages([historyTool], [threadTool], []);

  expect(merged).toEqual([threadTool]);
});
|
||||
@@ -72,6 +72,7 @@ def find_config_file() -> Path | None:
|
||||
|
||||
|
||||
_SECTION_RE = re.compile(r"^([A-Za-z_][\w-]*)\s*:\s*$")
|
||||
_INDENTED_SECTION_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*$")
|
||||
_KEY_RE = re.compile(r"^\s+([A-Za-z_][\w-]*)\s*:\s*(\S.*?)\s*$")
|
||||
|
||||
|
||||
@@ -141,6 +142,84 @@ def section_value(lines: list[str], section: str, key: str) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def nested_section_value(lines: list[str], section_path: str, key: str) -> str | None:
    """Look up ``key`` inside a two-level section path such as ``channels.discord``.

    The expected layout is::

        channels:
          discord:
            enabled: true

    Args:
        lines: Raw lines of the config file.
        section_path: Dotted ``parent.child`` path; any other number of
            segments is rejected.
        key: Name of the key to read directly under the child section.

    Returns:
        The value passed through ``_unquote``, or ``None`` when the path does
        not have exactly two segments or the key is not found.
    """
    segments = section_path.split(".")
    if len(segments) != 2:
        return None
    outer_name, inner_name = segments

    in_outer = False
    in_inner = False
    # Indentation of the outer section's direct children / the inner section's
    # direct children; each is learned from the first line seen at that level.
    outer_child_indent: int | None = None
    inner_child_indent: int | None = None

    for raw_line in lines:
        text = _strip_comment(raw_line)
        if not text.strip():
            continue

        indent = len(text) - len(text.lstrip())

        # A top-level section header opens a new section and resets all state.
        top_match = _SECTION_RE.match(text)
        if top_match:
            if indent == 0:
                in_outer = top_match.group(1) == outer_name
                in_inner = False
                outer_child_indent = None
                inner_child_indent = None
            continue

        if not in_outer:
            continue

        # The first indented line inside the outer section fixes its child level.
        if outer_child_indent is None and indent > 0:
            outer_child_indent = indent

        # Any other unindented line means the outer section has ended.
        if indent == 0:
            in_outer = False
            in_inner = False
            continue

        # Lines at the outer section's child level either open the wanted
        # subsection or close whichever subsection was open.
        if outer_child_indent is not None and indent == outer_child_indent:
            header = _INDENTED_SECTION_RE.match(text)
            in_inner = bool(header) and header.group(1) == inner_name
            if in_inner:
                inner_child_indent = None
            continue

        if not in_inner:
            continue

        # The first deeper line inside the subsection fixes its key level.
        if inner_child_indent is None and indent > (outer_child_indent or 0):
            inner_child_indent = indent

        # Ignore anything that is not a direct child of the subsection.
        if inner_child_indent is not None and indent != inner_child_indent:
            continue

        found = _KEY_RE.match(text)
        if found and found.group(1) == key:
            return _unquote(found.group(2).strip())

    return None
|
||||
|
||||
|
||||
def detect_from_config(path: Path) -> list[str]:
|
||||
try:
|
||||
text = path.read_text(encoding="utf-8", errors="replace")
|
||||
@@ -152,6 +231,8 @@ def detect_from_config(path: Path) -> list[str]:
|
||||
extras.add("postgres")
|
||||
if (section_value(lines, "checkpointer", "type") or "").lower() == "postgres":
|
||||
extras.add("postgres")
|
||||
if (nested_section_value(lines, "channels.discord", "enabled") or "").lower() == "true":
|
||||
extras.add("discord")
|
||||
return sorted(extras)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user