mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Support all integrated IM channel connections
This commit is contained in:
@@ -14,7 +14,8 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -136,6 +137,7 @@ class DingTalkChannel(Channel):
|
||||
self._incoming_messages: dict[str, Any] = {}
|
||||
self._incoming_messages_lock = threading.Lock()
|
||||
self._card_repliers: dict[str, Any] = {}
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
@@ -395,6 +397,24 @@ class DingTalkChannel(Channel):
|
||||
text[:100],
|
||||
)
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
conversation_type=conversation_type,
|
||||
sender_staff_id=sender_staff_id,
|
||||
sender_nick=sender_nick,
|
||||
conversation_id=conversation_id,
|
||||
code=connect_code,
|
||||
),
|
||||
self._main_loop,
|
||||
)
|
||||
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
||||
else:
|
||||
logger.warning("[DingTalk] main loop not running, cannot bind channel connection")
|
||||
return
|
||||
|
||||
if _is_dingtalk_command(text):
|
||||
msg_type = InboundMessageType.COMMAND
|
||||
else:
|
||||
@@ -450,11 +470,95 @@ class DingTalkChannel(Channel):
|
||||
return ""
|
||||
|
||||
async def _prepare_inbound(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
# Running reply must finish before publish_inbound so AI card tracks are
|
||||
# registered before the manager emits streaming outbounds.
|
||||
await self._send_running_reply(chat_id, inbound)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
@staticmethod
|
||||
def _connection_workspace_id(conversation_type: str, conversation_id: str) -> str | None:
|
||||
if conversation_type == _CONVERSATION_TYPE_GROUP and conversation_id:
|
||||
return conversation_id
|
||||
return None
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
conversation_type = str(inbound.metadata.get("conversation_type") or _CONVERSATION_TYPE_P2P)
|
||||
conversation_id = str(inbound.metadata.get("conversation_id") or "")
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="dingtalk",
|
||||
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
||||
fallback_without_workspace=True,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(
|
||||
self,
|
||||
*,
|
||||
conversation_type: str,
|
||||
sender_staff_id: str,
|
||||
sender_nick: str,
|
||||
conversation_id: str,
|
||||
code: str,
|
||||
) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="dingtalk", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(
|
||||
conversation_type,
|
||||
sender_staff_id,
|
||||
conversation_id,
|
||||
"DingTalk connection code is invalid or expired.",
|
||||
)
|
||||
return True
|
||||
|
||||
if not sender_staff_id:
|
||||
await self._send_connection_reply(
|
||||
conversation_type,
|
||||
sender_staff_id,
|
||||
conversation_id,
|
||||
"DingTalk connection could not be completed from this message.",
|
||||
)
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="dingtalk",
|
||||
external_account_id=sender_staff_id,
|
||||
external_account_name=sender_nick or None,
|
||||
workspace_id=self._connection_workspace_id(conversation_type, conversation_id),
|
||||
metadata={
|
||||
"conversation_type": conversation_type,
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(
|
||||
conversation_type,
|
||||
sender_staff_id,
|
||||
conversation_id,
|
||||
"DingTalk connected to DeerFlow.",
|
||||
)
|
||||
return True
|
||||
|
||||
async def _send_connection_reply(
|
||||
self,
|
||||
conversation_type: str,
|
||||
sender_staff_id: str,
|
||||
conversation_id: str,
|
||||
text: str,
|
||||
) -> None:
|
||||
robot_code = self._client_id
|
||||
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
||||
if conversation_id:
|
||||
await self._send_text_message_to_group(robot_code, conversation_id, text)
|
||||
return
|
||||
if sender_staff_id:
|
||||
await self._send_text_message_to_user(robot_code, sender_staff_id, text)
|
||||
|
||||
async def _send_running_reply(self, chat_id: str, inbound: InboundMessage) -> None:
|
||||
conversation_type = inbound.metadata.get("conversation_type", _CONVERSATION_TYPE_P2P)
|
||||
sender_staff_id = inbound.metadata.get("sender_staff_id", "")
|
||||
|
||||
@@ -11,7 +11,8 @@ import time
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import (
|
||||
PENDING_CLARIFICATION_METADATA_KEY,
|
||||
RESOLVED_FROM_PENDING_CLARIFICATION_METADATA_KEY,
|
||||
@@ -71,6 +72,7 @@ class FeishuChannel(Channel):
|
||||
self._CreateImageRequestBody = None
|
||||
self._GetMessageResourceRequest = None
|
||||
self._thread_lock = threading.Lock()
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@staticmethod
|
||||
def _non_empty_str(value: Any) -> str | None:
|
||||
@@ -726,11 +728,47 @@ class FeishuChannel(Channel):
|
||||
|
||||
async def _prepare_inbound(self, msg_id: str, inbound) -> None:
|
||||
"""Kick off Feishu side effects without delaying inbound dispatch."""
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
reaction_task = asyncio.create_task(self._add_reaction(msg_id, "OK"))
|
||||
self._track_background_task(reaction_task, name="add_reaction", msg_id=msg_id)
|
||||
self._ensure_running_card_started(msg_id)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="feishu",
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, message_id: str, chat_id: str, user_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="feishu", state=code)
|
||||
if state is None:
|
||||
await self._reply_card(message_id, "Feishu connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
if not user_id or not chat_id:
|
||||
await self._reply_card(message_id, "Feishu connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="feishu",
|
||||
external_account_id=user_id,
|
||||
workspace_id=chat_id,
|
||||
metadata={
|
||||
"chat_id": chat_id,
|
||||
"message_id": message_id,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._reply_card(message_id, "Feishu connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
def _on_message(self, event) -> None:
|
||||
"""Called by lark-oapi when a message is received (runs in lark thread)."""
|
||||
try:
|
||||
@@ -819,6 +857,23 @@ class FeishuChannel(Channel):
|
||||
logger.info("[Feishu] empty text, ignoring message")
|
||||
return
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
if self._main_loop and self._main_loop.is_running():
|
||||
fut = asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
message_id=msg_id,
|
||||
chat_id=chat_id,
|
||||
user_id=sender_id,
|
||||
code=connect_code,
|
||||
),
|
||||
self._main_loop,
|
||||
)
|
||||
fut.add_done_callback(lambda f, mid=msg_id: self._log_future_error(f, "bind_connection", mid))
|
||||
else:
|
||||
logger.warning("[Feishu] main loop not running, cannot bind channel connection")
|
||||
return
|
||||
|
||||
# Only treat known slash commands as commands; absolute paths and
|
||||
# other slash-prefixed text should be handled as normal chat.
|
||||
if _is_feishu_command(text):
|
||||
|
||||
@@ -22,8 +22,9 @@ from cryptography.hazmat.primitives import padding
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.message_bus import InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -253,6 +254,7 @@ class WechatChannel(Channel):
|
||||
self._state_dir = self._resolve_state_dir(config.get("state_dir"))
|
||||
self._cursor_path = self._state_dir / "wechat-getupdates.json" if self._state_dir else None
|
||||
self._auth_path = self._state_dir / "wechat-auth.json" if self._state_dir else None
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
self._load_state()
|
||||
|
||||
async def start(self) -> None:
|
||||
@@ -617,6 +619,16 @@ class WechatChannel(Channel):
|
||||
if thread_ts:
|
||||
self._context_tokens_by_thread[thread_ts] = context_token
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
handled = await self._bind_connection_from_connect_code(
|
||||
chat_id=chat_id,
|
||||
context_token=context_token,
|
||||
code=connect_code,
|
||||
)
|
||||
if handled:
|
||||
return
|
||||
|
||||
inbound = self._make_inbound(
|
||||
chat_id=chat_id,
|
||||
user_id=chat_id,
|
||||
@@ -632,8 +644,54 @@ class WechatChannel(Channel):
|
||||
},
|
||||
)
|
||||
inbound.topic_id = None
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="wechat",
|
||||
workspace_id=inbound.chat_id,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, chat_id: str, context_token: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="wechat", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(chat_id, context_token, "WeChat connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
if not chat_id:
|
||||
await self._send_connection_reply(chat_id, context_token, "WeChat connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="wechat",
|
||||
external_account_id=chat_id,
|
||||
workspace_id=chat_id,
|
||||
metadata={
|
||||
"context_token": context_token,
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(chat_id, context_token, "WeChat connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _send_connection_reply(self, chat_id: str, context_token: str, text: str) -> None:
|
||||
if not context_token:
|
||||
return
|
||||
await self._send_text_message(
|
||||
chat_id=chat_id,
|
||||
context_token=context_token,
|
||||
text=text,
|
||||
client_id_prefix="deerflow-connect",
|
||||
max_retries=1,
|
||||
)
|
||||
|
||||
async def _ensure_authenticated(self) -> bool:
|
||||
async with self._auth_lock:
|
||||
if self._bot_token:
|
||||
|
||||
@@ -8,8 +8,10 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import (
|
||||
InboundMessage,
|
||||
InboundMessageType,
|
||||
MessageBus,
|
||||
OutboundMessage,
|
||||
@@ -29,6 +31,7 @@ class WeComChannel(Channel):
|
||||
self._ws_frames: dict[str, dict[str, Any]] = {}
|
||||
self._ws_stream_ids: dict[str, str] = {}
|
||||
self._working_message = "Working on it..."
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
@@ -271,6 +274,16 @@ class WeComChannel(Channel):
|
||||
|
||||
user_id = (body.get("from") or {}).get("userid")
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
handled = await self._bind_connection_from_connect_code(
|
||||
frame=frame,
|
||||
user_id=str(user_id or ""),
|
||||
code=connect_code,
|
||||
)
|
||||
if handled:
|
||||
return
|
||||
|
||||
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
||||
inbound = self._make_inbound(
|
||||
chat_id=user_id, # keep user's conversation in memory
|
||||
@@ -292,8 +305,52 @@ class WeComChannel(Channel):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="wecom",
|
||||
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
|
||||
fallback_without_workspace=True,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
if not user_id:
|
||||
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
body = frame.get("body", {}) or {}
|
||||
workspace_id = str(body.get("aibotid") or "") or None
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="wecom",
|
||||
external_account_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
metadata={
|
||||
"aibotid": workspace_id,
|
||||
"chattype": body.get("chattype"),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
|
||||
if not self._ws_client:
|
||||
return
|
||||
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
|
||||
|
||||
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||
if not self._ws_client:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user