mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 09:55:59 +00:00
Support all integrated IM channel connections
This commit is contained in:
@@ -8,8 +8,10 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any, cast
|
||||
|
||||
from app.channels.base import Channel
|
||||
from app.channels.commands import is_known_channel_command
|
||||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||||
from app.channels.connection_identity import attach_connection_identity
|
||||
from app.channels.message_bus import (
|
||||
InboundMessage,
|
||||
InboundMessageType,
|
||||
MessageBus,
|
||||
OutboundMessage,
|
||||
@@ -29,6 +31,7 @@ class WeComChannel(Channel):
|
||||
self._ws_frames: dict[str, dict[str, Any]] = {}
|
||||
self._ws_stream_ids: dict[str, str] = {}
|
||||
self._working_message = "Working on it..."
|
||||
self._connection_repo = config.get("connection_repo")
|
||||
|
||||
@property
|
||||
def supports_streaming(self) -> bool:
|
||||
@@ -271,6 +274,16 @@ class WeComChannel(Channel):
|
||||
|
||||
user_id = (body.get("from") or {}).get("userid")
|
||||
|
||||
connect_code = extract_connect_code(text)
|
||||
if connect_code and self._connection_repo is not None:
|
||||
handled = await self._bind_connection_from_connect_code(
|
||||
frame=frame,
|
||||
user_id=str(user_id or ""),
|
||||
code=connect_code,
|
||||
)
|
||||
if handled:
|
||||
return
|
||||
|
||||
inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
|
||||
inbound = self._make_inbound(
|
||||
chat_id=user_id, # keep user's conversation in memory
|
||||
@@ -292,8 +305,52 @@ class WeComChannel(Channel):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
await self.bus.publish_inbound(inbound)
|
||||
|
||||
async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
|
||||
return await attach_connection_identity(
|
||||
inbound,
|
||||
repo=self._connection_repo,
|
||||
provider="wecom",
|
||||
workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
|
||||
fallback_without_workspace=True,
|
||||
)
|
||||
|
||||
async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
|
||||
if self._connection_repo is None or not code:
|
||||
return False
|
||||
|
||||
state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
|
||||
if state is None:
|
||||
await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
|
||||
return True
|
||||
|
||||
if not user_id:
|
||||
await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
|
||||
return True
|
||||
|
||||
body = frame.get("body", {}) or {}
|
||||
workspace_id = str(body.get("aibotid") or "") or None
|
||||
await self._connection_repo.upsert_connection(
|
||||
owner_user_id=state["owner_user_id"],
|
||||
provider="wecom",
|
||||
external_account_id=user_id,
|
||||
workspace_id=workspace_id,
|
||||
metadata={
|
||||
"aibotid": workspace_id,
|
||||
"chattype": body.get("chattype"),
|
||||
},
|
||||
status="connected",
|
||||
)
|
||||
await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
|
||||
return True
|
||||
|
||||
async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
|
||||
if not self._ws_client:
|
||||
return
|
||||
await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})
|
||||
|
||||
async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
|
||||
if not self._ws_client:
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user