mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-18 05:25:57 +00:00
8c0830aea1
* fix(channels): add operational guardrails * make format * fix(channels): converge with #3582 to avoid merge-order conflicts Drop this PR's DingTalk INFO-log redaction and hand it to #3582, which already restructures that handler and will redact the same log there. This PR no longer touches dingtalk.py, so the two PRs can merge to main in any order without a conflict. For WeChat, drop the contested thread_ts priority reorder (review #3) and keep only what inbound dedupe needs: a server-stable message_id in the inbound metadata (message_id/msg_id, no client_id per review #6). This is a single added line inside the metadata dict, a region #3582 never touches, so it auto-merges regardless of order. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * fix(channels): address three correctness review findings 1. Connect-code cap was racy (willem #1): _create_state ran delete-expired, count, and insert as three separate transactions, so concurrent connect POSTs from one owner could each see count < cap and all insert past it. Add ChannelConnectionRepository.create_oauth_state_within_cap which does delete+count+insert in a single transaction serialized per (owner, provider) — Postgres via pg_advisory_xact_lock, SQLite via the write lock the leading DELETE takes — and have the router use it. 2. Inbound dedupe key fell back to "" workspace (willem #3): two workspaces delivering without team/guild/aibotid would collapse to the same key and dedupe each other's messages. _inbound_dedupe_key now fails closed (returns None) when no workspace identifier is present. 3. Dedupe key was recorded on receipt and never released on failure (ShenAC #1): a transient error (DB blip, Gateway 503) left the key in place for the full TTL, so a provider redelivery of the same message_id — exactly the retry dedupe should absorb — was silently dropped. 
_handle_message now releases the key in the unexpected-exception branch so redelivery can recover, while keeping record-on-receipt so retries during handling are still deduped. Tests: repo cap enforcement incl. concurrent-issuance non-leak; dedupe fail-closed; dedupe key release-on-failure redelivery recovery. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * fix(channels): address cleanup/efficiency and test review findings Efficiency / cleanup: - Dedupe key set drops client-generated ids (client_msg_id, client_id); keep only server-stable event_id/message_id/msg_id, which a provider's own redelivery preserves (ShenAC #6). Every provider already emits message_id. - TTL/overflow pruning of _recent_inbound_events is now O(k): switch to an OrderedDict and popitem(last=False) from the front instead of scanning all 4096 entries on every inbound (willem #4). - Log "received inbound" only after the dedupe check so a provider retrying N times no longer logs N accepts; document that manager dedupe covers the agent run/final answer, not provider ack side-effects (willem #5, ShenAC #2). - Slack drops the redundant `team_id or event.get("team")` fallback the caller already resolved (willem #6). - create_oauth_state_within_cap prunes only this owner/provider's expired codes instead of a global DELETE on every connect POST; global cleanup still runs on consume_oauth_state (willem #7). Tests: - Dedupe test uses tmp_path instead of a leaked mkdtemp, uses distinct objects per publish, and adds a negative control: a different message_id is still processed, catching over-dedupe regressions (willem #8, ShenAC #4). - Slack HTTP-mode rejection test supplies app_token so the missing-token early return can't mask the guard, giving the state assertions teeth (ShenAC #3). - count_oauth_states test pins that the active row survives, not just the count (ShenAC #5). 
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * make format --------- Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
472 lines
18 KiB
Python
472 lines
18 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import base64
|
||
import hashlib
|
||
import logging
|
||
from collections.abc import Awaitable, Callable
|
||
from typing import Any, cast
|
||
|
||
from app.channels.base import Channel
|
||
from app.channels.commands import extract_connect_code, is_known_channel_command
|
||
from app.channels.connection_identity import attach_connection_identity
|
||
from app.channels.message_bus import (
|
||
InboundMessage,
|
||
InboundMessageType,
|
||
MessageBus,
|
||
OutboundMessage,
|
||
ResolvedAttachment,
|
||
)
|
||
|
||
# Module-scoped logger; its name is this module's import path, so log level and
# handlers for the WeCom channel can be configured independently.
logger: logging.Logger = logging.getLogger(name=__name__)
|
||
|
||
|
||
class WeComChannel(Channel):
    """WeCom AI-bot channel driven by the ``wecom-aibot-python-sdk`` WebSocket client.

    Inbound text / mixed / image / file frames are turned into ``InboundMessage``
    objects and published on the bus. The inbound ``msgid`` is used as
    ``thread_ts``. The originating frame and a stream id are remembered under
    that key, so outbound messages can be streamed back as replies to the
    original frame. Outbound messages with no remembered frame are sent as
    standalone markdown messages to ``chat_id``.
    """

    def __init__(self, bus: MessageBus, config: dict[str, Any]) -> None:
        super().__init__(name="wecom", bus=bus, config=config)
        self._bot_id: str | None = None
        self._bot_secret: str | None = None
        # SDK ``WSClient`` instance; created in ``start()``, cleared in ``stop()``.
        self._ws_client: Any = None
        self._ws_task: asyncio.Task | None = None
        # thread_ts (inbound msgid) -> inbound frame, required to reply in place.
        self._ws_frames: dict[str, dict[str, Any]] = {}
        # thread_ts (inbound msgid) -> stream id used for streaming replies.
        self._ws_stream_ids: dict[str, str] = {}
        self._working_message = "Working on it..."
        # Optional connection repository used for connect-code binding / identity lookup.
        self._connection_repo = config.get("connection_repo")

    @property
    def supports_streaming(self) -> bool:
        """WeCom replies are streamed via ``reply_stream``."""
        return True

    def _clear_ws_context(self, thread_ts: str | None) -> None:
        """Forget the remembered frame and stream id for ``thread_ts`` (no-op if falsy)."""
        if not thread_ts:
            return
        self._ws_frames.pop(thread_ts, None)
        self._ws_stream_ids.pop(thread_ts, None)

    async def _send_ws_upload_command(self, req_id: str, body: dict[str, Any], cmd: str) -> dict[str, Any]:
        """Send a raw media-upload command through the SDK's internal WebSocket manager.

        Raises:
            RuntimeError: if no client is available, or the installed SDK does not
                expose ``_ws_manager.send_reply``.
        """
        if not self._ws_client:
            raise RuntimeError("WeCom WebSocket client is not available")

        # Private SDK API: probed defensively so an incompatible SDK version fails loudly.
        ws_manager = getattr(self._ws_client, "_ws_manager", None)
        send_reply = getattr(ws_manager, "send_reply", None)
        if not callable(send_reply):
            raise RuntimeError("Installed wecom-aibot-python-sdk does not expose the WebSocket media upload API expected by DeerFlow. Use wecom-aibot-python-sdk==0.1.6 or update the adapter.")

        send_reply_async = cast(Callable[[str, dict[str, Any], str], Awaitable[dict[str, Any]]], send_reply)
        return await send_reply_async(req_id, body, cmd)

    async def start(self) -> None:
        """Read credentials from config, connect the WebSocket client and subscribe to outbound.

        Logs an error and returns without starting when ``bot_id``/``bot_secret`` are
        missing or the SDK is not installed.
        """
        if self._running:
            return

        bot_id = self.config.get("bot_id")
        bot_secret = self.config.get("bot_secret")
        working_message = self.config.get("working_message")

        self._bot_id = bot_id if isinstance(bot_id, str) and bot_id else None
        self._bot_secret = bot_secret if isinstance(bot_secret, str) and bot_secret else None
        self._working_message = working_message if isinstance(working_message, str) and working_message else "Working on it..."

        if not self._bot_id or not self._bot_secret:
            logger.error("WeCom channel requires bot_id and bot_secret")
            return

        try:
            from aibot import WSClient, WSClientOptions
        except ImportError:
            logger.error("wecom-aibot-python-sdk is not installed. Install it with: uv add wecom-aibot-python-sdk")
            return

        self._ws_client = WSClient(WSClientOptions(bot_id=self._bot_id, secret=self._bot_secret, logger=logger))
        self._ws_client.on("message.text", self._on_ws_text)
        self._ws_client.on("message.mixed", self._on_ws_mixed)
        self._ws_client.on("message.image", self._on_ws_image)
        self._ws_client.on("message.file", self._on_ws_file)
        self._ws_client.on("error", self._on_ws_error)
        self._ws_client.on("disconnected", self._on_ws_disconnected)
        self._ws_task = asyncio.create_task(self._ws_client.connect())
        self._ws_task.add_done_callback(self._on_ws_task_done)

        self._running = True
        self.bus.subscribe_outbound(self._on_outbound)
        logger.info("WeCom channel started")

    def _on_ws_task_done(self, task: asyncio.Task) -> None:
        """Log a failed connection task; cancellation and clean completion are ignored."""
        if task.cancelled():
            return
        exc = task.exception()
        if exc is None:
            return
        logger.error(
            "WeCom WebSocket connection task failed: %s. Check that the network/proxy allows wss://openws.work.weixin.qq.com and that bot_id/bot_secret are valid.",
            exc,
        )

    def _on_ws_error(self, error: Any) -> None:
        logger.error("WeCom WebSocket error: %s", error)

    def _on_ws_disconnected(self, *args: Any) -> None:
        detail = f" ({args[0]})" if args else ""
        logger.warning("WeCom WebSocket disconnected%s; SDK will attempt to reconnect", detail)

    async def stop(self) -> None:
        """Unsubscribe, cancel the connection task, disconnect and drop all reply context.

        Cancel/disconnect failures are best-effort: they are logged at debug level
        rather than raised, so shutdown always completes.
        """
        self._running = False
        self.bus.unsubscribe_outbound(self._on_outbound)
        if self._ws_task:
            try:
                self._ws_task.cancel()
            except Exception:
                logger.debug("[WeCom] error cancelling WebSocket task", exc_info=True)
            self._ws_task = None
        if self._ws_client:
            try:
                self._ws_client.disconnect()
            except Exception:
                logger.debug("[WeCom] error disconnecting WebSocket client", exc_info=True)
            self._ws_client = None
        self._ws_frames.clear()
        self._ws_stream_ids.clear()
        logger.info("WeCom channel stopped")

    async def send(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
        """Send ``msg`` over the WebSocket client, or log a warning if it is unavailable."""
        if self._ws_client:
            await self._send_ws(msg, _max_retries=_max_retries)
            return
        logger.warning("[WeCom] send called but WebSocket client is not available")

    async def _on_outbound(self, msg: OutboundMessage) -> None:
        """Bus subscriber: deliver text, then attachments, for messages addressed to this channel.

        Reply context for the thread is cleared after a final message, even if sending failed.
        """
        if msg.channel_name != self.name:
            return

        try:
            await self.send(msg)
        except Exception:
            logger.exception("Failed to send outbound message on channel %s", self.name)
            if msg.is_final:
                self._clear_ws_context(msg.thread_ts)
            return

        for attachment in msg.attachments:
            try:
                success = await self.send_file(msg, attachment)
                if not success:
                    logger.warning("[%s] file upload skipped for %s", self.name, attachment.filename)
            except Exception:
                logger.exception("[%s] failed to upload file %s", self.name, attachment.filename)

        if msg.is_final:
            self._clear_ws_context(msg.thread_ts)

    async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
        """Upload ``attachment`` and reply with it to the frame of ``msg.thread_ts``.

        Non-final messages return ``True`` without sending anything. Returns ``False``
        when there is no client or frame, when the size limit (2 MiB for images,
        20 MiB otherwise) is exceeded, or when the upload/reply fails.
        """
        if not msg.is_final:
            return True
        if not self._ws_client:
            return False
        if not msg.thread_ts:
            return False
        frame = self._ws_frames.get(msg.thread_ts)
        if not frame:
            return False

        media_type = "image" if attachment.is_image else "file"
        size_limit = 2 * 1024 * 1024 if attachment.is_image else 20 * 1024 * 1024
        if attachment.size > size_limit:
            logger.warning(
                "[WeCom] %s too large (%d bytes), skipping: %s",
                media_type,
                attachment.size,
                attachment.filename,
            )
            return False

        try:
            media_id = await self._upload_media_ws(
                media_type=media_type,
                filename=attachment.filename,
                path=str(attachment.actual_path),
                size=attachment.size,
            )
            if not media_id:
                return False

            body = {media_type: {"media_id": media_id}, "msgtype": media_type}
            await self._ws_client.reply(frame, body)
            logger.debug("[WeCom] %s sent via ws: %s", media_type, attachment.filename)
            return True
        except Exception:
            logger.exception("[WeCom] failed to upload/send file via ws: %s", attachment.filename)
            return False

    async def _on_ws_text(self, frame: dict[str, Any]) -> None:
        """Handle a ``message.text`` frame, appending any quoted text to the message."""
        body = frame.get("body", {}) or {}
        text = ((body.get("text") or {}).get("content") or "").strip()
        # Guard every level with ``or`` (as for ``text`` above) so a null ``quote``,
        # ``text`` or ``content`` field does not raise AttributeError.
        quote = (((body.get("quote") or {}).get("text") or {}).get("content") or "").strip()
        if not text and not quote:
            return
        await self._publish_ws_inbound(frame, text + (f"\nQuote message: {quote}" if quote else ""))

    async def _on_ws_mixed(self, frame: dict[str, Any]) -> None:
        """Handle a ``message.mixed`` frame: join text items and collect image/file items."""
        body = frame.get("body", {}) or {}
        mixed = body.get("mixed") or {}
        items = mixed.get("msg_item") or []
        parts: list[str] = []
        files: list[dict[str, Any]] = []
        for item in items:
            item_type = (item or {}).get("msgtype")
            if item_type == "text":
                content = (((item or {}).get("text") or {}).get("content") or "").strip()
                if content:
                    parts.append(content)
            elif item_type in ("image", "file"):
                payload = (item or {}).get(item_type) or {}
                url = payload.get("url")
                aeskey = payload.get("aeskey")
                if isinstance(url, str) and url:
                    files.append(
                        {
                            "type": item_type,
                            "url": url,
                            "aeskey": (aeskey if isinstance(aeskey, str) and aeskey else None),
                        }
                    )
        text = "\n\n".join(parts).strip()
        if not text and not files:
            return
        if not text:
            text = "(receive image/file)"
        await self._publish_ws_inbound(frame, text, files=files)

    async def _on_ws_image(self, frame: dict[str, Any]) -> None:
        """Handle a ``message.image`` frame; frames without a URL are ignored."""
        body = frame.get("body", {}) or {}
        image = body.get("image") or {}
        url = image.get("url")
        aeskey = image.get("aeskey")
        if not isinstance(url, str) or not url:
            return
        await self._publish_ws_inbound(
            frame,
            "(receive image )",
            files=[
                {
                    "type": "image",
                    "url": url,
                    "aeskey": aeskey if isinstance(aeskey, str) and aeskey else None,
                }
            ],
        )

    async def _on_ws_file(self, frame: dict[str, Any]) -> None:
        """Handle a ``message.file`` frame; frames without a URL are ignored."""
        body = frame.get("body", {}) or {}
        file_obj = body.get("file") or {}
        url = file_obj.get("url")
        aeskey = file_obj.get("aeskey")
        if not isinstance(url, str) or not url:
            return
        await self._publish_ws_inbound(
            frame,
            "(receive file)",
            files=[
                {
                    "type": "file",
                    "url": url,
                    "aeskey": aeskey if isinstance(aeskey, str) and aeskey else None,
                }
            ],
        )

    async def _publish_ws_inbound(
        self,
        frame: dict[str, Any],
        text: str,
        *,
        files: list[dict[str, Any]] | None = None,
    ) -> None:
        """Convert an inbound frame into an ``InboundMessage`` and publish it.

        Connect-code messages are consumed by ``_bind_connection_from_connect_code``
        when a connection repo is configured. Otherwise the frame and a new stream id
        are remembered under the frame's ``msgid``. A "working" placeholder is then
        streamed, and the message is published with connection identity attached.
        Frames without ``msgid`` are dropped.
        """
        ws_client = self._ws_client
        if not ws_client:
            return
        try:
            from aibot import generate_req_id
        except Exception:
            return

        body = frame.get("body", {}) or {}
        msg_id = body.get("msgid")
        if not msg_id:
            return

        user_id = (body.get("from") or {}).get("userid")

        connect_code = extract_connect_code(text)
        if connect_code and self._connection_repo is not None:
            handled = await self._bind_connection_from_connect_code(
                frame=frame,
                user_id=str(user_id or ""),
                code=connect_code,
            )
            if handled:
                return

        inbound_type = InboundMessageType.COMMAND if is_known_channel_command(text) else InboundMessageType.CHAT
        inbound = self._make_inbound(
            chat_id=user_id,  # keep user's conversation in memory
            user_id=user_id,
            text=text,
            msg_type=inbound_type,
            thread_ts=msg_id,
            files=files or [],
            metadata={
                "aibotid": body.get("aibotid"),
                "chattype": body.get("chattype"),
                "message_id": msg_id,
            },
        )
        inbound.topic_id = user_id  # keep the same thread

        stream_id = generate_req_id("stream")
        self._ws_frames[msg_id] = frame
        self._ws_stream_ids[msg_id] = stream_id

        # The placeholder is cosmetic, so a failure must not block the message.
        # It is logged rather than silently ignored.
        try:
            await ws_client.reply_stream(frame, stream_id, self._working_message, False)
        except Exception:
            logger.warning("[WeCom] failed to send working message for msgid %s", msg_id, exc_info=True)

        inbound = await self._attach_connection_identity(inbound)
        await self.bus.publish_inbound(inbound)

    async def _attach_connection_identity(self, inbound: InboundMessage) -> InboundMessage:
        """Attach the connected DeerFlow identity, using ``aibotid`` as the workspace id."""
        return await attach_connection_identity(
            inbound,
            repo=self._connection_repo,
            provider="wecom",
            workspace_id=str(inbound.metadata.get("aibotid") or "") or None,
            fallback_without_workspace=True,
        )

    async def _bind_connection_from_connect_code(self, *, frame: dict[str, Any], user_id: str, code: str) -> bool:
        """Consume a connect code and bind the sender to the code's owner.

        Returns ``True`` when the message was handled (a reply was sent), ``False``
        only when there is no repo or no code.
        """
        if self._connection_repo is None or not code:
            return False

        state = await self._connection_repo.consume_oauth_state(provider="wecom", state=code)
        if state is None:
            await self._send_connection_reply(frame, "WeCom connection code is invalid or expired.")
            return True

        if not user_id:
            await self._send_connection_reply(frame, "WeCom connection could not be completed from this message.")
            return True

        body = frame.get("body", {}) or {}
        workspace_id = str(body.get("aibotid") or "") or None
        await self._connection_repo.upsert_connection(
            owner_user_id=state["owner_user_id"],
            provider="wecom",
            external_account_id=user_id,
            workspace_id=workspace_id,
            metadata={
                "aibotid": workspace_id,
                "chattype": body.get("chattype"),
            },
            status="connected",
        )
        await self._send_connection_reply(frame, "WeCom connected to DeerFlow.")
        return True

    async def _send_connection_reply(self, frame: dict[str, Any], text: str) -> None:
        """Reply to ``frame`` with plain text, if a client is available."""
        if not self._ws_client:
            return
        await self._ws_client.reply(frame, {"msgtype": "text", "text": {"content": text}})

    async def _send_ws(self, msg: OutboundMessage, *, _max_retries: int = 3) -> None:
        """Stream ``msg`` as a reply to its remembered frame, or send it as markdown.

        The client is captured once, so retry attempts do not dereference
        ``self._ws_client`` after ``stop()`` has reset it to ``None``.
        """
        client = self._ws_client
        if not client:
            return
        try:
            from aibot import generate_req_id
        except Exception:
            generate_req_id = None

        if msg.thread_ts and msg.thread_ts in self._ws_frames:
            frame = self._ws_frames[msg.thread_ts]
            stream_id = self._ws_stream_ids.get(msg.thread_ts)
            if not stream_id and generate_req_id:
                stream_id = generate_req_id("stream")
                self._ws_stream_ids[msg.thread_ts] = stream_id
            if not stream_id:
                return

            await self._send_with_retry(
                lambda: client.reply_stream(frame, stream_id, msg.text, bool(msg.is_final)),
                max_retries=_max_retries,
                log_prefix="[WeCom]",
                operation_name="stream send",
            )
            return

        body = {"msgtype": "markdown", "markdown": {"content": msg.text}}
        await self._send_with_retry(
            lambda: client.send_message(msg.chat_id, body),
            max_retries=_max_retries,
            log_prefix="[WeCom]",
        )

    @staticmethod
    def _file_md5(path: str) -> str:
        """Return the hex MD5 digest of the file at ``path``, read in 1 MiB chunks.

        The digest is an upload integrity checksum, not a security primitive. It is
        flagged ``usedforsecurity=False`` so it still works on FIPS-restricted builds.
        """
        hasher = hashlib.md5(usedforsecurity=False)
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(1024 * 1024), b""):
                hasher.update(chunk)
        return hasher.hexdigest()

    async def _upload_media_ws(
        self,
        *,
        media_type: str,
        filename: str,
        path: str,
        size: int,
    ) -> str | None:
        """Upload a local file in 512 KiB chunks via the init/chunk/finish commands.

        Returns the resulting ``media_id``. Returns ``None`` when there is no client,
        the SDK is missing, the chunk count falls outside 1..100, or the server omits
        ``upload_id``/``media_id``.
        """
        if not self._ws_client:
            return None
        try:
            from aibot import generate_req_id
        except Exception:
            return None

        chunk_size = 512 * 1024
        total_chunks = (size + chunk_size - 1) // chunk_size
        if total_chunks < 1 or total_chunks > 100:
            logger.warning("[WeCom] invalid total_chunks=%d for %s", total_chunks, filename)
            return None

        # Hashing up to 20 MiB is blocking disk I/O; keep it off the event loop.
        md5 = await asyncio.to_thread(self._file_md5, path)

        init_req_id = generate_req_id("aibot_upload_media_init")
        init_body = {
            "type": media_type,
            "filename": filename,
            "total_size": int(size),
            "total_chunks": int(total_chunks),
            "md5": md5,
        }
        init_ack = await self._send_ws_upload_command(init_req_id, init_body, "aibot_upload_media_init")
        upload_id = (init_ack.get("body") or {}).get("upload_id")
        if not upload_id:
            logger.warning("[WeCom] upload init returned no upload_id: %s", init_ack)
            return None

        with open(path, "rb") as f:
            for idx in range(total_chunks):
                data = f.read(chunk_size)
                if not data:
                    break
                chunk_req_id = generate_req_id("aibot_upload_media_chunk")
                chunk_body = {
                    "upload_id": upload_id,
                    "chunk_index": int(idx),
                    "base64_data": base64.b64encode(data).decode("utf-8"),
                }
                await self._send_ws_upload_command(chunk_req_id, chunk_body, "aibot_upload_media_chunk")

        finish_req_id = generate_req_id("aibot_upload_media_finish")
        finish_ack = await self._send_ws_upload_command(finish_req_id, {"upload_id": upload_id}, "aibot_upload_media_finish")
        media_id = (finish_ack.get("body") or {}).get("media_id")
        if not media_id:
            logger.warning("[WeCom] upload finish returned no media_id: %s", finish_ack)
            return None
        return media_id
|