fix(channels): centralize shared channel retry helpers (#3583)

This commit is contained in:
Nan Gao
2026-06-17 09:44:40 +02:00
committed by GitHub
parent c81ab268fb
commit e732a741bf
8 changed files with 193 additions and 178 deletions
+53 -1
View File
@@ -2,14 +2,19 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from collections.abc import Awaitable, Callable
from concurrent.futures import CancelledError as FutureCancelledError
from typing import Any, TypeVar
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
T = TypeVar("T")
class Channel(ABC): class Channel(ABC):
"""Base class for all IM channel implementations. """Base class for all IM channel implementations.
@@ -65,6 +70,53 @@ class Channel(ABC):
# -- helpers ----------------------------------------------------------- # -- helpers -----------------------------------------------------------
async def _send_with_retry(
self,
operation: Callable[[], Awaitable[T]],
*,
max_retries: int,
log_prefix: str | None = None,
operation_name: str = "send",
) -> T:
"""Run an outbound send operation with the shared channel retry policy."""
prefix = log_prefix or f"[{self.name}]"
last_exc: Exception | None = None
for attempt in range(max_retries):
try:
return await operation()
except Exception as exc:
last_exc = exc
if attempt < max_retries - 1:
delay = 2**attempt
logger.warning(
"%s %s failed (attempt %d/%d), retrying in %ds: %s",
prefix,
operation_name,
attempt + 1,
max_retries,
delay,
exc,
)
await asyncio.sleep(delay)
logger.error("%s %s failed after %d attempts: %s", prefix, operation_name, max_retries, last_exc)
if last_exc is None:
raise RuntimeError(f"{self.name} {operation_name} failed without an exception from any attempt")
raise last_exc
def _log_future_error(self, fut: Any, name: str, msg_id: Any) -> None:
"""Callback for concurrent futures scheduled from channel worker threads."""
try:
exc = fut.exception()
except (asyncio.CancelledError, FutureCancelledError, asyncio.InvalidStateError):
return
except Exception:
logger.exception("[%s] failed to inspect future for %s (msg_id=%s)", self.name, name, msg_id)
return
if exc:
logger.error("[%s] %s failed for msg_id=%s: %s", self.name, name, msg_id, exc)
def _make_inbound( def _make_inbound(
self, self,
chat_id: str, chat_id: str,
+12 -34
View File
@@ -247,32 +247,19 @@ class DingTalkChannel(Channel):
self._card_repliers.pop(out_track_id, None) self._card_repliers.pop(out_track_id, None)
return return
# Non-card mode: send sampleMarkdown with retry async def send_markdown() -> None:
last_exc: Exception | None = None if conversation_type == _CONVERSATION_TYPE_GROUP:
for attempt in range(_max_retries): await self._send_group_message(robot_code, conversation_id, msg.text, at_user_ids=[sender_staff_id] if sender_staff_id else None)
try: else:
if conversation_type == _CONVERSATION_TYPE_GROUP: await self._send_p2p_message(robot_code, sender_staff_id, msg.text)
await self._send_group_message(robot_code, conversation_id, msg.text, at_user_ids=[sender_staff_id] if sender_staff_id else None)
else:
await self._send_p2p_message(robot_code, sender_staff_id, msg.text)
return
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
delay = 2**attempt
logger.warning(
"[DingTalk] send failed (attempt %d/%d), retrying in %ds: %s",
attempt + 1,
_max_retries,
delay,
exc,
)
await asyncio.sleep(delay)
logger.error("[DingTalk] send failed after %d attempts: %s", _max_retries, last_exc) # Non-card mode: send sampleMarkdown with retry
if last_exc is None: await self._send_with_retry(
raise RuntimeError("DingTalk send failed without an exception from any attempt") send_markdown,
raise last_exc max_retries=_max_retries,
log_prefix="[DingTalk]",
)
return
async def _send_markdown_fallback( async def _send_markdown_fallback(
self, self,
@@ -802,15 +789,6 @@ class DingTalkChannel(Channel):
logger.exception("[DingTalk] failed to upload media: %s", file_path) logger.exception("[DingTalk] failed to upload media: %s", file_path)
return None return None
@staticmethod
def _log_future_error(fut: Any, name: str, msg_id: str) -> None:
try:
exc = fut.exception()
if exc:
logger.error("[DingTalk] %s failed for msg_id=%s: %s", name, msg_id, exc)
except (asyncio.CancelledError, asyncio.InvalidStateError):
pass
class _DingTalkMessageHandler: class _DingTalkMessageHandler:
"""Callback handler registered with dingtalk-stream.""" """Callback handler registered with dingtalk-stream."""
+5 -32
View File
@@ -241,28 +241,11 @@ class FeishuChannel(Channel):
len(msg.text), len(msg.text),
) )
last_exc: Exception | None = None await self._send_with_retry(
for attempt in range(_max_retries): lambda: self._send_card_message(msg),
try: max_retries=_max_retries,
await self._send_card_message(msg) log_prefix="[Feishu]",
return # success )
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
delay = 2**attempt # 1s, 2s
logger.warning(
"[Feishu] send failed (attempt %d/%d), retrying in %ds: %s",
attempt + 1,
_max_retries,
delay,
exc,
)
await asyncio.sleep(delay)
logger.error("[Feishu] send failed after %d attempts: %s", _max_retries, last_exc)
if last_exc is None:
raise RuntimeError("Feishu send failed without an exception from any attempt")
raise last_exc
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._api_client: if not self._api_client:
@@ -725,16 +708,6 @@ class FeishuChannel(Channel):
return root_id or msg_id, False return root_id or msg_id, False
@staticmethod
def _log_future_error(fut, name: str, msg_id: str) -> None:
"""Callback for run_coroutine_threadsafe futures to surface errors."""
try:
exc = fut.exception()
if exc:
logger.error("[Feishu] %s failed for msg_id=%s: %s", name, msg_id, exc)
except Exception:
pass
@staticmethod @staticmethod
def _log_task_error(task: asyncio.Task, name: str, msg_id: str) -> None: def _log_task_error(task: asyncio.Task, name: str, msg_id: str) -> None:
"""Callback for background asyncio tasks to surface errors.""" """Callback for background asyncio tasks to surface errors."""
+26 -37
View File
@@ -141,49 +141,38 @@ class SlackChannel(Channel):
if msg.thread_ts: if msg.thread_ts:
kwargs["thread_ts"] = msg.thread_ts kwargs["thread_ts"] = msg.thread_ts
last_exc: Exception | None = None async def post_message() -> None:
for attempt in range(_max_retries): await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
try: # Add a completion reaction to the thread root
await asyncio.to_thread(web_client.chat_postMessage, **kwargs) if msg.thread_ts:
# Add a completion reaction to the thread root
if msg.thread_ts:
await asyncio.to_thread(
self._add_reaction_with_client,
web_client,
msg.chat_id,
msg.thread_ts,
"white_check_mark",
)
return
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
delay = 2**attempt # 1s, 2s
logger.warning(
"[Slack] send failed (attempt %d/%d), retrying in %ds: %s",
attempt + 1,
_max_retries,
delay,
exc,
)
await asyncio.sleep(delay)
logger.error("[Slack] send failed after %d attempts: %s", _max_retries, last_exc)
# Add failure reaction on error
if msg.thread_ts:
try:
await asyncio.to_thread( await asyncio.to_thread(
self._add_reaction_with_client, self._add_reaction_with_client,
web_client, web_client,
msg.chat_id, msg.chat_id,
msg.thread_ts, msg.thread_ts,
"x", "white_check_mark",
) )
except Exception:
pass try:
if last_exc is None: await self._send_with_retry(
raise RuntimeError("Slack send failed without an exception from any attempt") post_message,
raise last_exc max_retries=_max_retries,
log_prefix="[Slack]",
)
except Exception:
# Add failure reaction on error
if msg.thread_ts:
try:
await asyncio.to_thread(
self._add_reaction_with_client,
web_client,
msg.chat_id,
msg.thread_ts,
"x",
)
except Exception:
pass
raise
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
web_client = await self._get_web_client_for_message(msg) web_client = await self._get_web_client_for_message(msg)
+10 -32
View File
@@ -239,29 +239,17 @@ class TelegramChannel(Channel):
kwargs["reply_to_message_id"] = reply_to kwargs["reply_to_message_id"] = reply_to
bot = self._application.bot bot = self._application.bot
last_exc: Exception | None = None
for attempt in range(_max_retries):
try:
sent = await bot.send_message(**kwargs)
self._last_bot_message[chat_key] = sent.message_id
return sent.message_id
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
delay = 2**attempt # 1s, 2s
logger.warning(
"[Telegram] send failed (attempt %d/%d), retrying in %ds: %s",
attempt + 1,
_max_retries,
delay,
exc,
)
await asyncio.sleep(delay)
logger.error("[Telegram] send failed after %d attempts: %s", _max_retries, last_exc) async def send_message() -> int:
if last_exc is None: sent = await bot.send_message(**kwargs)
raise RuntimeError("Telegram send failed without an exception from any attempt") self._last_bot_message[chat_key] = sent.message_id
raise last_exc return sent.message_id
return await self._send_with_retry(
send_message,
max_retries=_max_retries,
log_prefix="[Telegram]",
)
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if not self._application: if not self._application:
@@ -368,16 +356,6 @@ class TelegramChannel(Channel):
except Exception: except Exception:
logger.exception("[Telegram] failed to send running reply in chat=%s", chat_id) logger.exception("[Telegram] failed to send running reply in chat=%s", chat_id)
# -- internal ----------------------------------------------------------
@staticmethod
def _log_future_error(fut, name: str, msg_id: str):
try:
exc = fut.exception()
if exc:
logger.error("[Telegram] %s failed for msg_id=%s: %s", name, msg_id, exc)
except Exception:
logger.exception("[Telegram] Failed to inspect future for %s (msg_id=%s)", name, msg_id)
def _run_polling(self) -> None: def _run_polling(self) -> None:
"""Run telegram polling in a dedicated thread.""" """Run telegram polling in a dedicated thread."""
self._tg_loop = asyncio.new_event_loop() self._tg_loop = asyncio.new_event_loop()
+8 -20
View File
@@ -342,27 +342,15 @@ class WechatChannel(Channel):
"base_info": self._base_info(), "base_info": self._base_info(),
} }
last_exc: Exception | None = None async def send_message() -> None:
for attempt in range(max_retries): data = await self._request_json("/ilink/bot/sendmessage", payload)
try: self._ensure_success(data, "sendmessage")
data = await self._request_json("/ilink/bot/sendmessage", payload)
self._ensure_success(data, "sendmessage")
return
except Exception as exc:
last_exc = exc
if attempt < max_retries - 1:
delay = 2**attempt
logger.warning(
"[WeChat] send failed (attempt %d/%d), retrying in %ds: %s",
attempt + 1,
max_retries,
delay,
exc,
)
await asyncio.sleep(delay)
logger.error("[WeChat] send failed after %d attempts: %s", max_retries, last_exc) await self._send_with_retry(
raise last_exc # type: ignore[misc] send_message,
max_retries=max_retries,
log_prefix="[WeChat]",
)
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool: async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
if attachment.is_image: if attachment.is_image:
+12 -22
View File
@@ -389,30 +389,20 @@ class WeComChannel(Channel):
if not stream_id: if not stream_id:
return return
last_exc: Exception | None = None await self._send_with_retry(
for attempt in range(_max_retries): lambda: self._ws_client.reply_stream(frame, stream_id, msg.text, bool(msg.is_final)),
try: max_retries=_max_retries,
await self._ws_client.reply_stream(frame, stream_id, msg.text, bool(msg.is_final)) log_prefix="[WeCom]",
return operation_name="stream send",
except Exception as exc: )
last_exc = exc return
if attempt < _max_retries - 1:
await asyncio.sleep(2**attempt)
if last_exc:
raise last_exc
body = {"msgtype": "markdown", "markdown": {"content": msg.text}} body = {"msgtype": "markdown", "markdown": {"content": msg.text}}
last_exc = None await self._send_with_retry(
for attempt in range(_max_retries): lambda: self._ws_client.send_message(msg.chat_id, body),
try: max_retries=_max_retries,
await self._ws_client.send_message(msg.chat_id, body) log_prefix="[WeCom]",
return )
except Exception as exc:
last_exc = exc
if attempt < _max_retries - 1:
await asyncio.sleep(2**attempt)
if last_exc:
raise last_exc
async def _upload_media_ws( async def _upload_media_ws(
self, self,
+67
View File
@@ -4,7 +4,9 @@ from __future__ import annotations
import asyncio import asyncio
import json import json
import logging
import tempfile import tempfile
from concurrent.futures import Future
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
@@ -333,6 +335,71 @@ class TestChannelBase:
_run(go()) _run(go())
def test_send_with_retry_retries_until_success(self, monkeypatch):
bus = MessageBus()
ch = DummyChannel(bus)
attempts = 0
sleep = AsyncMock()
monkeypatch.setattr("app.channels.base.asyncio.sleep", sleep)
async def flaky_send():
nonlocal attempts
attempts += 1
if attempts < 3:
raise RuntimeError(f"failure {attempts}")
return "sent"
result = _run(ch._send_with_retry(flaky_send, max_retries=3, log_prefix="[Dummy]"))
assert result == "sent"
assert attempts == 3
assert [call.args[0] for call in sleep.await_args_list] == [1, 2]
def test_log_future_error_handles_cancelled_future(self, caplog):
bus = MessageBus()
ch = DummyChannel(bus)
fut = Future()
fut.cancel()
with caplog.at_level(logging.ERROR):
ch._log_future_error(fut, "prepare_inbound", "m1")
assert "prepare_inbound" not in caplog.text
def test_log_future_error_surfaces_future_exception(self, caplog):
bus = MessageBus()
ch = DummyChannel(bus)
fut = Future()
fut.set_exception(RuntimeError("boom"))
with caplog.at_level(logging.ERROR):
ch._log_future_error(fut, "prepare_inbound", "m1")
assert "prepare_inbound failed for msg_id=m1: boom" in caplog.text
def test_channel_capabilities_match_channel_defaults(self):
from app.channels.dingtalk import DingTalkChannel
from app.channels.discord import DiscordChannel
from app.channels.feishu import FeishuChannel
from app.channels.manager import CHANNEL_CAPABILITIES
from app.channels.slack import SlackChannel
from app.channels.telegram import TelegramChannel
from app.channels.wechat import WechatChannel
from app.channels.wecom import WeComChannel
bus = MessageBus()
defaults = {
"dingtalk": DingTalkChannel(bus=bus, config={}).supports_streaming,
"discord": DiscordChannel(bus=bus, config={}).supports_streaming,
"feishu": FeishuChannel(bus=bus, config={}).supports_streaming,
"slack": SlackChannel(bus=bus, config={}).supports_streaming,
"telegram": TelegramChannel(bus=bus, config={}).supports_streaming,
"wechat": WechatChannel(bus=bus, config={}).supports_streaming,
"wecom": WeComChannel(bus=bus, config={}).supports_streaming,
}
assert {name: caps["supports_streaming"] for name, caps in CHANNEL_CAPABILITIES.items()} == defaults
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# _extract_response_text tests # _extract_response_text tests