mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-10 17:35:57 +00:00
16391e35ab
* support slash skill activation * format slash skill activation * Preserve slash skill activation with uploads * Address slash skill review feedback * Address slash skill follow-up review * Fix lazy slash skill storage resolution * Keep slash skill activation out of system prompt * Address slash skill review issues * fix: harden slash skill command handling * feat(frontend): add slash skill autocomplete * fix: address slash skill review feedback * fix: preserve slash skill text for IM uploads
555 lines
24 KiB
Python
555 lines
24 KiB
Python
"""Discord channel integration using discord.py."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import threading
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from app.channels.base import Channel
|
|
from app.channels.commands import is_known_channel_command
|
|
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
|
|
|
# Module-level logger shared by every "[Discord]" message emitted below.
logger: logging.Logger = logging.getLogger(__name__)

# Maximum number of characters Discord accepts in a single message's content;
# outgoing text is split into chunks no longer than this.
_DISCORD_MAX_MESSAGE_LEN: int = 2_000
|
|
|
|
|
|
class DiscordChannel(Channel):
    """Discord bot channel.

    Configuration keys (in ``config.yaml`` under ``channels.discord``):
        - ``bot_token``: Discord Bot token.
        - ``allowed_guilds``: (optional) List of allowed Discord guild IDs. Empty = allow all.
        - ``mention_only``: (optional) If true, only respond when the bot is mentioned.
        - ``allowed_channels``: (optional) List of channel IDs where messages are always accepted
          (even when mention_only is true). Use for channels where you want the bot to respond
          without mentions. Empty = mention_only applies everywhere.
        - ``thread_mode``: (optional) If true, group a channel conversation into a thread.
          Default: same as ``mention_only``.

    An optional ``channel_store`` entry (an object exposing ``_path``) decides
    where the channel -> thread mapping file ``discord_threads.json`` lives;
    without it the file goes to ``~/.deer-flow/channels/``.

    discord.py runs on a private event loop in a daemon thread (``_run_client``).
    ``start``/``stop``/``send``/``send_file`` run on the caller's loop and hand
    work to the Discord loop; inbound messages travel the other way through
    ``_publish`` onto the loop captured in ``start``.
    """

    def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
        super().__init__(name="discord", bus=bus, config=config)
        # ``or ""`` / ``or []`` treat keys that are present but empty (``None`` in
        # YAML) like missing keys instead of yielding the token "None" or a TypeError.
        self._bot_token = str(config.get("bot_token") or "").strip()

        self._allowed_guilds: set[int] = set()
        for guild_id in config.get("allowed_guilds") or []:
            try:
                self._allowed_guilds.add(int(guild_id))
            except (TypeError, ValueError):
                logger.warning("[Discord] ignoring invalid guild id in allowed_guilds: %r", guild_id)

        self._mention_only: bool = bool(config.get("mention_only", False))
        self._thread_mode: bool = bool(config.get("thread_mode", self._mention_only))
        self._allowed_channels: set[str] = {str(channel_id) for channel_id in config.get("allowed_channels") or []}

        # Session tracking: channel_id -> Discord thread_id (in-memory, persisted to JSON).
        # Uses a dedicated JSON file separate from ChannelStore, which maps IM
        # conversations to DeerFlow thread IDs — a different concern.
        self._active_threads: dict[str, str] = {}
        # Reverse-lookup set for O(1) "is this a tracked thread?" checks.
        self._active_thread_ids: set[str] = set()
        # Protects both maps and the JSON file; the Discord loop thread and the
        # caller's thread both touch them.
        self._thread_store_lock = threading.Lock()
        store = config.get("channel_store")
        if store is not None:
            self._thread_store_path = store._path.parent / "discord_threads.json"
        else:
            self._thread_store_path = Path.home() / ".deer-flow" / "channels" / "discord_threads.json"

        # Typing-indicator loops keyed by thread/channel id; created on the Discord loop.
        self._typing_tasks: dict[str, asyncio.Task] = {}
        # Strong references to fire-and-forget tasks so they are not garbage-collected mid-flight.
        self._background_tasks: set[asyncio.Task] = set()

        self._client = None
        self._thread: threading.Thread | None = None
        self._discord_loop: asyncio.AbstractEventLoop | None = None
        self._main_loop: asyncio.AbstractEventLoop | None = None
        self._discord_module = None

    # ------------------------------------------------------------------ lifecycle

    async def start(self) -> None:
        """Connect the bot on a background thread and subscribe to outbound messages.

        Logs an error and returns without starting when discord.py is not
        installed or no ``bot_token`` is configured; a no-op when already running.
        """
        if self._running:
            return

        try:
            import discord
        except ImportError:
            logger.error("discord.py is not installed. Install it with: uv add discord.py")
            return

        if not self._bot_token:
            logger.error("Discord channel requires bot_token")
            return

        intents = discord.Intents.default()
        intents.messages = True
        intents.guilds = True
        intents.message_content = True

        client = discord.Client(
            intents=intents,
            allowed_mentions=discord.AllowedMentions.none(),
        )
        self._client = client
        self._discord_module = discord
        self._main_loop = asyncio.get_running_loop()

        @client.event
        async def on_message(message) -> None:
            await self._on_message(message)

        # Restore persisted mappings *before* the client thread starts so the
        # very first messages are routed with them.
        self._load_active_threads()

        self._running = True
        self.bus.subscribe_outbound(self._on_outbound)

        self._thread = threading.Thread(target=self._run_client, daemon=True)
        self._thread.start()
        logger.info("Discord channel started")

    def _load_active_threads(self) -> None:
        """Restore Discord thread mappings from the dedicated JSON file.

        A missing file is not an error. An unreadable or malformed file is
        logged and leaves the in-memory mappings untouched.
        """
        with self._thread_store_lock:
            try:
                if not self._thread_store_path.exists():
                    logger.debug("[Discord] no thread mappings file at %s", self._thread_store_path)
                    return
                data = json.loads(self._thread_store_path.read_text(encoding="utf-8"))
                if not isinstance(data, dict):
                    logger.warning("[Discord] ignoring malformed thread mappings file %s", self._thread_store_path)
                    return
                self._active_threads.clear()
                self._active_thread_ids.clear()
                for channel_id, thread_id in data.items():
                    self._active_threads[str(channel_id)] = str(thread_id)
                    self._active_thread_ids.add(str(thread_id))
                if self._active_threads:
                    logger.info("[Discord] restored %d thread mappings from %s", len(self._active_threads), self._thread_store_path)
            except Exception:
                logger.exception("[Discord] failed to load thread mappings")

    def _save_thread(self, channel_id: str, thread_id: str) -> None:
        """Make *thread_id* the active thread of *channel_id* and persist the mapping.

        The in-memory maps are updated first and unconditionally, so routing
        keeps working even when the JSON file cannot be read or written; such
        I/O failures are only logged. The previous thread of the channel (if
        any) stops being tracked.
        """
        with self._thread_store_lock:
            old_id = self._active_threads.get(channel_id)
            self._active_threads[channel_id] = thread_id
            if old_id and old_id != thread_id:
                self._active_thread_ids.discard(old_id)
            self._active_thread_ids.add(thread_id)

            try:
                # Merge into the on-disk mappings rather than overwriting them, so
                # entries not (yet) loaded into memory are preserved.
                data: dict[str, str] = {}
                if self._thread_store_path.exists():
                    loaded = json.loads(self._thread_store_path.read_text(encoding="utf-8"))
                    if isinstance(loaded, dict):
                        data = loaded
                data[channel_id] = thread_id
                self._thread_store_path.parent.mkdir(parents=True, exist_ok=True)
                # Write a sibling temp file and rename it into place so a crash
                # mid-write cannot leave a truncated mappings file behind.
                tmp_path = self._thread_store_path.with_name(self._thread_store_path.name + ".tmp")
                tmp_path.write_text(json.dumps(data, indent=2), encoding="utf-8")
                tmp_path.replace(self._thread_store_path)
            except Exception:
                logger.exception("[Discord] failed to save thread mapping for channel %s", channel_id)

    async def stop(self) -> None:
        """Stop typing loops, close the client and wait for the client thread to exit.

        Closing the client and joining the thread are each bounded by 10 seconds.
        """
        self._running = False
        self.bus.unsubscribe_outbound(self._on_outbound)

        loop = self._discord_loop
        if loop is not None and loop.is_running():
            # Typing tasks belong to the Discord loop and asyncio objects are not
            # thread-safe, so cancel them from that loop.
            cancel_future = asyncio.run_coroutine_threadsafe(self._cancel_typing_tasks(), loop)
            try:
                await asyncio.wait_for(asyncio.wrap_future(cancel_future), timeout=10)
            except TimeoutError:
                logger.warning("[Discord] cancelling typing tasks timed out after 10s")
            except Exception:
                logger.exception("[Discord] error while cancelling typing tasks")
        else:
            await self._cancel_typing_tasks()

        if self._client and loop is not None and loop.is_running():
            close_future = asyncio.run_coroutine_threadsafe(self._client.close(), loop)
            try:
                await asyncio.wait_for(asyncio.wrap_future(close_future), timeout=10)
            except TimeoutError:
                logger.warning("[Discord] client close timed out after 10s")
            except Exception:
                logger.exception("[Discord] error while closing client")

        if self._thread:
            # Thread.join blocks; run it off this event loop so other work continues.
            await asyncio.to_thread(self._thread.join, 10)
            self._thread = None

        self._client = None
        self._discord_loop = None
        self._discord_module = None
        logger.info("Discord channel stopped")

    def _run_client(self) -> None:
        """Thread target: run the discord.py client on a private loop until it stops."""
        loop = asyncio.new_event_loop()
        self._discord_loop = loop
        asyncio.set_event_loop(loop)
        try:
            loop.run_until_complete(self._client.start(self._bot_token))
        except Exception:
            if self._running:
                logger.exception("Discord client error")
        finally:
            try:
                client = self._client
                if client and not client.is_closed():
                    loop.run_until_complete(client.close())
            except Exception:
                logger.exception("Error during Discord shutdown")
            self._shutdown_loop(loop)

    @staticmethod
    def _shutdown_loop(loop: asyncio.AbstractEventLoop) -> None:
        """Cancel tasks still pending on *loop*, then close it (like ``asyncio.run``'s cleanup)."""
        try:
            pending = [task for task in asyncio.all_tasks(loop) if not task.done()]
            for task in pending:
                task.cancel()
            if pending:
                loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
            loop.run_until_complete(loop.shutdown_asyncgens())
        except Exception:
            logger.exception("Error during Discord shutdown")
        finally:
            loop.close()

    # ------------------------------------------------------------------ outbound

    async def send(self, msg: OutboundMessage) -> None:
        """Deliver ``msg.text`` to its thread/channel, split into Discord-sized chunks."""
        if self._discord_loop is None:
            logger.error("[Discord] cannot send: client not running (chat_id=%s thread_ts=%s)", msg.chat_id, msg.thread_ts)
            return

        # The response is arriving, so the typing indicator can stop.
        await self._run_on_discord_loop(self._stop_typing(msg.chat_id, msg.thread_ts))

        target = await self._resolve_target(msg)
        if target is None:
            logger.error("[Discord] target not found for chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
            return

        for chunk in self._split_text(msg.text or ""):
            if not chunk.strip():
                # Discord rejects messages with empty / whitespace-only content.
                logger.debug("[Discord] skipping empty message chunk for chat_id=%s", msg.chat_id)
                continue
            await self._run_on_discord_loop(target.send(chunk))

    async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
        """Upload *attachment* to the message's thread/channel; return ``True`` on success."""
        if self._discord_loop is None:
            logger.error("[Discord] cannot upload file: client not running (chat_id=%s thread_ts=%s)", msg.chat_id, msg.thread_ts)
            return False

        await self._run_on_discord_loop(self._stop_typing(msg.chat_id, msg.thread_ts))

        target = await self._resolve_target(msg)
        if target is None:
            logger.error("[Discord] target not found for file upload chat_id=%s thread_ts=%s", msg.chat_id, msg.thread_ts)
            return False

        discord_mod = self._discord_module
        if discord_mod is None:
            return False

        upload = None
        try:
            # Pass the path rather than an already-open handle: discord.File then
            # opens the file itself and discord.py closes it after sending, whereas
            # a caller-supplied handle is never closed by the library.
            upload = discord_mod.File(str(attachment.actual_path), filename=attachment.filename)
            await self._run_on_discord_loop(target.send(file=upload))
            logger.info("[Discord] file uploaded: %s", attachment.filename)
            return True
        except Exception:
            logger.exception("[Discord] failed to upload file: %s", attachment.filename)
            if upload is not None:
                # Release the handle if the upload never got far enough for
                # discord.py to close it (closing twice is harmless).
                upload.close()
            return False

    async def _run_on_discord_loop(self, coro):
        """Run *coro* on the Discord loop and return its result to the awaiting loop.

        Raises:
            RuntimeError: the Discord loop is not available (or already closed).
        """
        loop = self._discord_loop
        if loop is None:
            coro.close()
            raise RuntimeError("Discord event loop is not available")
        if asyncio.get_running_loop() is loop:
            # Already on the Discord loop (e.g. reached from _on_message).
            return await coro
        try:
            future = asyncio.run_coroutine_threadsafe(coro, loop)
        except Exception:
            coro.close()  # never scheduled; avoid a "coroutine was never awaited" warning
            raise
        return await asyncio.wrap_future(future)

    # ------------------------------------------------------------------ typing / reactions

    async def _start_typing(self, channel, chat_id: str, thread_ts: str | None = None) -> None:
        """Start a background loop that re-sends the typing indicator every 10 seconds.

        Runs on the Discord loop. A no-op when a loop is already active for the
        target (``thread_ts`` if given, else ``chat_id``).
        """
        target_id = thread_ts or chat_id
        existing = self._typing_tasks.get(target_id)
        if existing is not None and not existing.done():
            return

        async def _typing_loop() -> None:
            try:
                while True:
                    try:
                        await self._trigger_typing(channel)
                    except Exception:
                        logger.debug("[Discord] failed to send typing indicator for target %s", target_id, exc_info=True)
                    await asyncio.sleep(10)
            except asyncio.CancelledError:
                pass

        self._typing_tasks[target_id] = asyncio.create_task(_typing_loop())

    @staticmethod
    async def _trigger_typing(channel) -> None:
        """Send one typing indicator to *channel*.

        discord.py 1.x offers ``trigger_typing()``; 2.x removed it in favour of
        awaiting ``typing()``. The former is preferred when present.
        """
        trigger = getattr(channel, "trigger_typing", None)
        if trigger is not None:
            await trigger()
        else:
            await channel.typing()

    async def _stop_typing(self, chat_id: str, thread_ts: str | None = None) -> None:
        """Stop the typing loop for a specific target (``thread_ts`` if given, else ``chat_id``)."""
        target_id = thread_ts or chat_id
        task = self._typing_tasks.pop(target_id, None)
        if task and not task.done():
            task.cancel()
            logger.debug("[Discord] stopped typing indicator for target %s", target_id)

    async def _cancel_typing_tasks(self) -> None:
        """Cancel every typing loop; meant to run on the loop that owns the tasks."""
        for target_id, task in list(self._typing_tasks.items()):
            if not task.done():
                task.cancel()
                logger.debug("[Discord] cancelled typing task for target %s", target_id)
        self._typing_tasks.clear()

    async def _add_reaction(self, message) -> None:
        """Add a checkmark reaction to acknowledge the message was received."""
        try:
            await message.add_reaction("✅")
        except Exception:
            logger.debug("[Discord] failed to add reaction to message %s", message.id, exc_info=True)

    def _spawn(self, coro) -> None:
        """Run *coro* as a fire-and-forget task, keeping a strong reference until it finishes."""
        task = asyncio.create_task(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)

    # ------------------------------------------------------------------ inbound

    async def _on_message(self, message) -> None:
        """Filter an incoming Discord message, choose its route and publish it.

        Registered as the client's ``on_message`` event, so it runs on the Discord loop.
        """
        if not self._running or not self._client or self._discord_module is None:
            return
        if message.author.bot:
            return
        user = self._client.user
        if user and message.author.id == user.id:
            return

        guild = message.guild
        if self._allowed_guilds and (guild is None or guild.id not in self._allowed_guilds):
            return

        text = (message.content or "").strip()
        if not text:
            return

        # Stripping the mention may leave empty text; it is still processed
        # (e.g. a bare @mention creates a thread).
        has_mention, text = self._strip_bot_mention(message.content, text, user)

        if isinstance(message.channel, self._discord_module.Thread):
            thread_obj = message.channel
            thread_id = str(thread_obj.id)
            if thread_id in self._active_thread_ids:
                # Tracked conversation thread: always accepted, no mention required.
                chat_id = str(thread_obj.parent_id or thread_obj.id)
                await self._dispatch(message, text, chat_id=chat_id, thread_id=thread_id, typing_target=thread_obj)
                return
            # Untracked ("orphaned") thread: fall through and treat the thread itself
            # as the conversation channel. _create_thread only accepts text/news
            # channels, so any replies end up in this thread.
            logger.debug("[Discord] message in untracked thread %s, handling it like a channel message", thread_id)

        channel_id = str(message.channel.id)
        route = await self._route_channel_message(message, channel_id, has_mention)
        if route is None:
            return
        thread_id, typing_target = route
        await self._dispatch(message, text, chat_id=channel_id, thread_id=thread_id, typing_target=typing_target)

    @staticmethod
    def _strip_bot_mention(content: str, text: str, user) -> tuple[bool, str]:
        """Detect a mention of *user* in *content* and remove it from *text*.

        Recognises ``user.mention`` plus the ``<@ID>`` and ``<@!ID>`` forms.
        Returns ``(has_mention, text)``; *text* is unchanged when there is no mention.
        """
        if user is None:
            return False, text
        forms = [form for form in (user.mention, f"<@!{user.id}>", f"<@{user.id}>") if form]
        if not any(form in content for form in forms):
            return False, text
        for form in forms:
            text = text.replace(form, "")
        return True, text.strip()

    async def _route_channel_message(self, message, channel_id: str, has_mention: bool) -> tuple[str, Any] | None:
        """Decide where a channel (or untracked-thread) message goes.

        Returns ``(thread_id, typing_target)``, where ``thread_id`` equals
        *channel_id* when replies go straight to the channel, or ``None`` when
        the message must be ignored because of ``mention_only``.
        """
        skip_unmentioned = self._mention_only and not has_mention and channel_id not in self._allowed_channels

        existing_thread_id = self._active_threads.get(channel_id)
        if existing_thread_id is not None:
            if skip_unmentioned:
                logger.debug("[Discord] skipping no-@ message in channel %s (not in thread)", channel_id)
                return None
            if self._mention_only and has_mention:
                # A fresh @mention starts a new thread instead of reusing the old one.
                thread_obj = await self._open_thread(message, channel_id)
                if thread_obj is None:
                    return channel_id, message.channel
                new_thread_id = str(thread_obj.id)
                logger.info("[Discord] created new thread %s in channel %s on mention (replacing existing thread)", new_thread_id, channel_id)
                return new_thread_id, thread_obj
            logger.debug("[Discord] routing message in channel %s to existing thread %s", channel_id, existing_thread_id)
            return existing_thread_id, await self._get_channel_or_thread(existing_thread_id)

        if skip_unmentioned:
            logger.debug("[Discord] skipping message without mention in channel %s", channel_id)
            return None

        if self._mention_only and has_mention:
            # First mention in this channel -> create a thread.
            thread_obj = await self._open_thread(message, channel_id)
            if thread_obj is None:
                return channel_id, message.channel
            new_thread_id = str(thread_obj.id)
            logger.info("[Discord] created thread %s in channel %s for user %s", new_thread_id, channel_id, message.author.display_name)
            return new_thread_id, thread_obj

        if self._thread_mode:
            # thread_mode without a triggering mention: still group the conversation in a thread.
            thread_obj = await self._open_thread(message, channel_id)
            if thread_obj is None:
                return channel_id, message.channel
            return str(thread_obj.id), thread_obj

        # No threading — reply directly in the channel.
        return channel_id, message.channel

    async def _open_thread(self, message, channel_id: str):
        """Create a thread from *message* and register it as the channel's active thread.

        Returns the thread object, or ``None`` (after logging the fallback) when
        the thread could not be created and replies should go to the channel.
        """
        thread_obj = await self._create_thread(message)
        if thread_obj is None:
            logger.info("[Discord] thread creation failed in channel %s, falling back to channel replies", channel_id)
            return None
        self._save_thread(channel_id, str(thread_obj.id))
        return thread_obj

    async def _dispatch(self, message, text: str, *, chat_id: str, thread_id: str, typing_target) -> None:
        """Publish *text* as an inbound message, start typing and acknowledge it with a reaction."""
        msg_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
        guild = message.guild
        inbound = self._make_inbound(
            chat_id=chat_id,
            user_id=str(message.author.id),
            text=text,
            msg_type=msg_type,
            thread_ts=thread_id,
            metadata={
                "guild_id": str(guild.id) if guild else None,
                "channel_id": str(message.channel.id),
                "message_id": str(message.id),
            },
        )
        inbound.topic_id = thread_id

        # Register the typing task inline (before publishing) so it already
        # exists when the reply's _stop_typing runs on this loop.
        if typing_target is not None:
            await self._start_typing(typing_target, chat_id, thread_id)

        self._publish(inbound)
        self._spawn(self._add_reaction(message))

    def _publish(self, inbound) -> None:
        """Publish an inbound message to the main event loop."""
        loop = self._main_loop
        if loop is None or not loop.is_running():
            logger.warning("[Discord] main event loop not running; dropping inbound message")
            return
        future = asyncio.run_coroutine_threadsafe(self.bus.publish_inbound(inbound), loop)
        future.add_done_callback(self._log_publish_failure)

    @staticmethod
    def _log_publish_failure(future) -> None:
        """Done-callback for ``_publish``: log a cancelled or failed publication."""
        if future.cancelled():
            logger.warning("[Discord] publish_inbound was cancelled")
            return
        exc = future.exception()
        if exc is not None:
            logger.error("[Discord] publish_inbound failed", exc_info=exc)

    # ------------------------------------------------------------------ Discord helpers

    async def _create_thread(self, message):
        """Create a thread from *message*; return it, or ``None`` when that is not possible."""
        discord_mod = self._discord_module
        if discord_mod is None:
            return None

        channel_type = None
        try:
            # Threads are only created from messages in text and news channels.
            channel_type = message.channel.type
            if channel_type not in (discord_mod.ChannelType.text, discord_mod.ChannelType.news):
                logger.info(
                    "[Discord] channel type %s (%s) does not support threads",
                    channel_type.value,
                    channel_type.name,
                )
                return None

            # Discord caps thread names at 100 characters.
            thread_name = f"deerflow-{message.author.display_name}-{message.id}"[:100]
            return await message.create_thread(name=thread_name)
        except discord_mod.errors.HTTPException as exc:
            if exc.code == 50024:
                logger.info(
                    "[Discord] cannot create thread in channel %s (error code 50024): %s",
                    message.channel.id,
                    channel_type.name if channel_type else "unknown",
                )
            else:
                logger.exception(
                    "[Discord] failed to create thread for message=%s (HTTPException %s)",
                    message.id,
                    exc.code,
                )
            return None
        except Exception:
            logger.exception("[Discord] failed to create thread for message=%s (threads may be disabled or missing permissions)", message.id)
            return None

    async def _resolve_target(self, msg: OutboundMessage):
        """Return the thread (``msg.thread_ts``) or else channel (``msg.chat_id``) to reply to."""
        if not self._client or not self._discord_loop:
            return None

        candidate_ids = [raw_id for raw_id in (msg.thread_ts, msg.chat_id) if raw_id]
        for raw_id in dict.fromkeys(candidate_ids):  # de-duplicate, keep order
            target = await self._get_channel_or_thread(raw_id)
            if target is not None:
                return target
        return None

    async def _get_channel_or_thread(self, raw_id: str):
        """Look up a channel or thread by its string id; ``None`` if invalid or unknown."""
        if not self._client or not self._discord_loop:
            return None

        try:
            target_id = int(raw_id)
        except (TypeError, ValueError):
            return None

        try:
            return await self._run_on_discord_loop(self._fetch_channel(target_id))
        except Exception:
            logger.exception("[Discord] failed to resolve target id=%s", raw_id)
            return None

    async def _fetch_channel(self, target_id: int):
        """Return the channel from the client cache, falling back to an API fetch."""
        client = self._client
        if not client:
            return None

        channel = client.get_channel(target_id)
        if channel is not None:
            return channel

        try:
            return await client.fetch_channel(target_id)
        except Exception:
            logger.debug("[Discord] fetch_channel(%s) failed", target_id, exc_info=True)
            return None

    @staticmethod
    def _split_text(text: str) -> list[str]:
        """Split *text* into chunks of at most ``_DISCORD_MAX_MESSAGE_LEN`` characters.

        Prefers breaking at the last newline inside the limit; otherwise cuts
        at the limit. Newlines at the start of each following chunk are
        dropped. Empty input yields ``[""]``.
        """
        if not text:
            return [""]

        chunks: list[str] = []
        remaining = text
        while len(remaining) > _DISCORD_MAX_MESSAGE_LEN:
            split_at = remaining.rfind("\n", 0, _DISCORD_MAX_MESSAGE_LEN)
            if split_at <= 0:
                split_at = _DISCORD_MAX_MESSAGE_LEN
            chunks.append(remaining[:split_at])
            remaining = remaining[split_at:].lstrip("\n")

        if remaining:
            chunks.append(remaining)

        return chunks