mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-18 13:46:02 +00:00
fix(channels): centralize shared channel retry helpers (#3583)
This commit is contained in:
@@ -2,14 +2,19 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from collections.abc import Awaitable, Callable
|
||||||
|
from concurrent.futures import CancelledError as FutureCancelledError
|
||||||
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
from app.channels.message_bus import InboundMessage, InboundMessageType, MessageBus, OutboundMessage, ResolvedAttachment
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class Channel(ABC):
|
class Channel(ABC):
|
||||||
"""Base class for all IM channel implementations.
|
"""Base class for all IM channel implementations.
|
||||||
@@ -65,6 +70,53 @@ class Channel(ABC):
|
|||||||
|
|
||||||
# -- helpers -----------------------------------------------------------
|
# -- helpers -----------------------------------------------------------
|
||||||
|
|
||||||
|
async def _send_with_retry(
|
||||||
|
self,
|
||||||
|
operation: Callable[[], Awaitable[T]],
|
||||||
|
*,
|
||||||
|
max_retries: int,
|
||||||
|
log_prefix: str | None = None,
|
||||||
|
operation_name: str = "send",
|
||||||
|
) -> T:
|
||||||
|
"""Run an outbound send operation with the shared channel retry policy."""
|
||||||
|
prefix = log_prefix or f"[{self.name}]"
|
||||||
|
last_exc: Exception | None = None
|
||||||
|
for attempt in range(max_retries):
|
||||||
|
try:
|
||||||
|
return await operation()
|
||||||
|
except Exception as exc:
|
||||||
|
last_exc = exc
|
||||||
|
if attempt < max_retries - 1:
|
||||||
|
delay = 2**attempt
|
||||||
|
logger.warning(
|
||||||
|
"%s %s failed (attempt %d/%d), retrying in %ds: %s",
|
||||||
|
prefix,
|
||||||
|
operation_name,
|
||||||
|
attempt + 1,
|
||||||
|
max_retries,
|
||||||
|
delay,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
|
||||||
|
logger.error("%s %s failed after %d attempts: %s", prefix, operation_name, max_retries, last_exc)
|
||||||
|
if last_exc is None:
|
||||||
|
raise RuntimeError(f"{self.name} {operation_name} failed without an exception from any attempt")
|
||||||
|
raise last_exc
|
||||||
|
|
||||||
|
def _log_future_error(self, fut: Any, name: str, msg_id: Any) -> None:
|
||||||
|
"""Callback for concurrent futures scheduled from channel worker threads."""
|
||||||
|
try:
|
||||||
|
exc = fut.exception()
|
||||||
|
except (asyncio.CancelledError, FutureCancelledError, asyncio.InvalidStateError):
|
||||||
|
return
|
||||||
|
except Exception:
|
||||||
|
logger.exception("[%s] failed to inspect future for %s (msg_id=%s)", self.name, name, msg_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
if exc:
|
||||||
|
logger.error("[%s] %s failed for msg_id=%s: %s", self.name, name, msg_id, exc)
|
||||||
|
|
||||||
def _make_inbound(
|
def _make_inbound(
|
||||||
self,
|
self,
|
||||||
chat_id: str,
|
chat_id: str,
|
||||||
|
|||||||
@@ -247,32 +247,19 @@ class DingTalkChannel(Channel):
|
|||||||
self._card_repliers.pop(out_track_id, None)
|
self._card_repliers.pop(out_track_id, None)
|
||||||
return
|
return
|
||||||
|
|
||||||
# Non-card mode: send sampleMarkdown with retry
|
async def send_markdown() -> None:
|
||||||
last_exc: Exception | None = None
|
|
||||||
for attempt in range(_max_retries):
|
|
||||||
try:
|
|
||||||
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
if conversation_type == _CONVERSATION_TYPE_GROUP:
|
||||||
await self._send_group_message(robot_code, conversation_id, msg.text, at_user_ids=[sender_staff_id] if sender_staff_id else None)
|
await self._send_group_message(robot_code, conversation_id, msg.text, at_user_ids=[sender_staff_id] if sender_staff_id else None)
|
||||||
else:
|
else:
|
||||||
await self._send_p2p_message(robot_code, sender_staff_id, msg.text)
|
await self._send_p2p_message(robot_code, sender_staff_id, msg.text)
|
||||||
return
|
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
if attempt < _max_retries - 1:
|
|
||||||
delay = 2**attempt
|
|
||||||
logger.warning(
|
|
||||||
"[DingTalk] send failed (attempt %d/%d), retrying in %ds: %s",
|
|
||||||
attempt + 1,
|
|
||||||
_max_retries,
|
|
||||||
delay,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
logger.error("[DingTalk] send failed after %d attempts: %s", _max_retries, last_exc)
|
# Non-card mode: send sampleMarkdown with retry
|
||||||
if last_exc is None:
|
await self._send_with_retry(
|
||||||
raise RuntimeError("DingTalk send failed without an exception from any attempt")
|
send_markdown,
|
||||||
raise last_exc
|
max_retries=_max_retries,
|
||||||
|
log_prefix="[DingTalk]",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
async def _send_markdown_fallback(
|
async def _send_markdown_fallback(
|
||||||
self,
|
self,
|
||||||
@@ -802,15 +789,6 @@ class DingTalkChannel(Channel):
|
|||||||
logger.exception("[DingTalk] failed to upload media: %s", file_path)
|
logger.exception("[DingTalk] failed to upload media: %s", file_path)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _log_future_error(fut: Any, name: str, msg_id: str) -> None:
|
|
||||||
try:
|
|
||||||
exc = fut.exception()
|
|
||||||
if exc:
|
|
||||||
logger.error("[DingTalk] %s failed for msg_id=%s: %s", name, msg_id, exc)
|
|
||||||
except (asyncio.CancelledError, asyncio.InvalidStateError):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class _DingTalkMessageHandler:
|
class _DingTalkMessageHandler:
|
||||||
"""Callback handler registered with dingtalk-stream."""
|
"""Callback handler registered with dingtalk-stream."""
|
||||||
|
|||||||
@@ -241,28 +241,11 @@ class FeishuChannel(Channel):
|
|||||||
len(msg.text),
|
len(msg.text),
|
||||||
)
|
)
|
||||||
|
|
||||||
last_exc: Exception | None = None
|
await self._send_with_retry(
|
||||||
for attempt in range(_max_retries):
|
lambda: self._send_card_message(msg),
|
||||||
try:
|
max_retries=_max_retries,
|
||||||
await self._send_card_message(msg)
|
log_prefix="[Feishu]",
|
||||||
return # success
|
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
if attempt < _max_retries - 1:
|
|
||||||
delay = 2**attempt # 1s, 2s
|
|
||||||
logger.warning(
|
|
||||||
"[Feishu] send failed (attempt %d/%d), retrying in %ds: %s",
|
|
||||||
attempt + 1,
|
|
||||||
_max_retries,
|
|
||||||
delay,
|
|
||||||
exc,
|
|
||||||
)
|
)
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
logger.error("[Feishu] send failed after %d attempts: %s", _max_retries, last_exc)
|
|
||||||
if last_exc is None:
|
|
||||||
raise RuntimeError("Feishu send failed without an exception from any attempt")
|
|
||||||
raise last_exc
|
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||||
if not self._api_client:
|
if not self._api_client:
|
||||||
@@ -725,16 +708,6 @@ class FeishuChannel(Channel):
|
|||||||
|
|
||||||
return root_id or msg_id, False
|
return root_id or msg_id, False
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _log_future_error(fut, name: str, msg_id: str) -> None:
|
|
||||||
"""Callback for run_coroutine_threadsafe futures to surface errors."""
|
|
||||||
try:
|
|
||||||
exc = fut.exception()
|
|
||||||
if exc:
|
|
||||||
logger.error("[Feishu] %s failed for msg_id=%s: %s", name, msg_id, exc)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _log_task_error(task: asyncio.Task, name: str, msg_id: str) -> None:
|
def _log_task_error(task: asyncio.Task, name: str, msg_id: str) -> None:
|
||||||
"""Callback for background asyncio tasks to surface errors."""
|
"""Callback for background asyncio tasks to surface errors."""
|
||||||
|
|||||||
@@ -141,9 +141,7 @@ class SlackChannel(Channel):
|
|||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
kwargs["thread_ts"] = msg.thread_ts
|
kwargs["thread_ts"] = msg.thread_ts
|
||||||
|
|
||||||
last_exc: Exception | None = None
|
async def post_message() -> None:
|
||||||
for attempt in range(_max_retries):
|
|
||||||
try:
|
|
||||||
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
|
await asyncio.to_thread(web_client.chat_postMessage, **kwargs)
|
||||||
# Add a completion reaction to the thread root
|
# Add a completion reaction to the thread root
|
||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
@@ -154,21 +152,14 @@ class SlackChannel(Channel):
|
|||||||
msg.thread_ts,
|
msg.thread_ts,
|
||||||
"white_check_mark",
|
"white_check_mark",
|
||||||
)
|
)
|
||||||
return
|
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
if attempt < _max_retries - 1:
|
|
||||||
delay = 2**attempt # 1s, 2s
|
|
||||||
logger.warning(
|
|
||||||
"[Slack] send failed (attempt %d/%d), retrying in %ds: %s",
|
|
||||||
attempt + 1,
|
|
||||||
_max_retries,
|
|
||||||
delay,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
logger.error("[Slack] send failed after %d attempts: %s", _max_retries, last_exc)
|
try:
|
||||||
|
await self._send_with_retry(
|
||||||
|
post_message,
|
||||||
|
max_retries=_max_retries,
|
||||||
|
log_prefix="[Slack]",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
# Add failure reaction on error
|
# Add failure reaction on error
|
||||||
if msg.thread_ts:
|
if msg.thread_ts:
|
||||||
try:
|
try:
|
||||||
@@ -181,9 +172,7 @@ class SlackChannel(Channel):
|
|||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
if last_exc is None:
|
raise
|
||||||
raise RuntimeError("Slack send failed without an exception from any attempt")
|
|
||||||
raise last_exc
|
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||||
web_client = await self._get_web_client_for_message(msg)
|
web_client = await self._get_web_client_for_message(msg)
|
||||||
|
|||||||
@@ -239,29 +239,17 @@ class TelegramChannel(Channel):
|
|||||||
kwargs["reply_to_message_id"] = reply_to
|
kwargs["reply_to_message_id"] = reply_to
|
||||||
|
|
||||||
bot = self._application.bot
|
bot = self._application.bot
|
||||||
last_exc: Exception | None = None
|
|
||||||
for attempt in range(_max_retries):
|
async def send_message() -> int:
|
||||||
try:
|
|
||||||
sent = await bot.send_message(**kwargs)
|
sent = await bot.send_message(**kwargs)
|
||||||
self._last_bot_message[chat_key] = sent.message_id
|
self._last_bot_message[chat_key] = sent.message_id
|
||||||
return sent.message_id
|
return sent.message_id
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
if attempt < _max_retries - 1:
|
|
||||||
delay = 2**attempt # 1s, 2s
|
|
||||||
logger.warning(
|
|
||||||
"[Telegram] send failed (attempt %d/%d), retrying in %ds: %s",
|
|
||||||
attempt + 1,
|
|
||||||
_max_retries,
|
|
||||||
delay,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
logger.error("[Telegram] send failed after %d attempts: %s", _max_retries, last_exc)
|
return await self._send_with_retry(
|
||||||
if last_exc is None:
|
send_message,
|
||||||
raise RuntimeError("Telegram send failed without an exception from any attempt")
|
max_retries=_max_retries,
|
||||||
raise last_exc
|
log_prefix="[Telegram]",
|
||||||
|
)
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||||
if not self._application:
|
if not self._application:
|
||||||
@@ -368,16 +356,6 @@ class TelegramChannel(Channel):
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.exception("[Telegram] failed to send running reply in chat=%s", chat_id)
|
logger.exception("[Telegram] failed to send running reply in chat=%s", chat_id)
|
||||||
|
|
||||||
# -- internal ----------------------------------------------------------
|
|
||||||
@staticmethod
|
|
||||||
def _log_future_error(fut, name: str, msg_id: str):
|
|
||||||
try:
|
|
||||||
exc = fut.exception()
|
|
||||||
if exc:
|
|
||||||
logger.error("[Telegram] %s failed for msg_id=%s: %s", name, msg_id, exc)
|
|
||||||
except Exception:
|
|
||||||
logger.exception("[Telegram] Failed to inspect future for %s (msg_id=%s)", name, msg_id)
|
|
||||||
|
|
||||||
def _run_polling(self) -> None:
|
def _run_polling(self) -> None:
|
||||||
"""Run telegram polling in a dedicated thread."""
|
"""Run telegram polling in a dedicated thread."""
|
||||||
self._tg_loop = asyncio.new_event_loop()
|
self._tg_loop = asyncio.new_event_loop()
|
||||||
|
|||||||
@@ -342,27 +342,15 @@ class WechatChannel(Channel):
|
|||||||
"base_info": self._base_info(),
|
"base_info": self._base_info(),
|
||||||
}
|
}
|
||||||
|
|
||||||
last_exc: Exception | None = None
|
async def send_message() -> None:
|
||||||
for attempt in range(max_retries):
|
|
||||||
try:
|
|
||||||
data = await self._request_json("/ilink/bot/sendmessage", payload)
|
data = await self._request_json("/ilink/bot/sendmessage", payload)
|
||||||
self._ensure_success(data, "sendmessage")
|
self._ensure_success(data, "sendmessage")
|
||||||
return
|
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
if attempt < max_retries - 1:
|
|
||||||
delay = 2**attempt
|
|
||||||
logger.warning(
|
|
||||||
"[WeChat] send failed (attempt %d/%d), retrying in %ds: %s",
|
|
||||||
attempt + 1,
|
|
||||||
max_retries,
|
|
||||||
delay,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
await asyncio.sleep(delay)
|
|
||||||
|
|
||||||
logger.error("[WeChat] send failed after %d attempts: %s", max_retries, last_exc)
|
await self._send_with_retry(
|
||||||
raise last_exc # type: ignore[misc]
|
send_message,
|
||||||
|
max_retries=max_retries,
|
||||||
|
log_prefix="[WeChat]",
|
||||||
|
)
|
||||||
|
|
||||||
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
async def send_file(self, msg: OutboundMessage, attachment: ResolvedAttachment) -> bool:
|
||||||
if attachment.is_image:
|
if attachment.is_image:
|
||||||
|
|||||||
@@ -389,30 +389,20 @@ class WeComChannel(Channel):
|
|||||||
if not stream_id:
|
if not stream_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
last_exc: Exception | None = None
|
await self._send_with_retry(
|
||||||
for attempt in range(_max_retries):
|
lambda: self._ws_client.reply_stream(frame, stream_id, msg.text, bool(msg.is_final)),
|
||||||
try:
|
max_retries=_max_retries,
|
||||||
await self._ws_client.reply_stream(frame, stream_id, msg.text, bool(msg.is_final))
|
log_prefix="[WeCom]",
|
||||||
|
operation_name="stream send",
|
||||||
|
)
|
||||||
return
|
return
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
if attempt < _max_retries - 1:
|
|
||||||
await asyncio.sleep(2**attempt)
|
|
||||||
if last_exc:
|
|
||||||
raise last_exc
|
|
||||||
|
|
||||||
body = {"msgtype": "markdown", "markdown": {"content": msg.text}}
|
body = {"msgtype": "markdown", "markdown": {"content": msg.text}}
|
||||||
last_exc = None
|
await self._send_with_retry(
|
||||||
for attempt in range(_max_retries):
|
lambda: self._ws_client.send_message(msg.chat_id, body),
|
||||||
try:
|
max_retries=_max_retries,
|
||||||
await self._ws_client.send_message(msg.chat_id, body)
|
log_prefix="[WeCom]",
|
||||||
return
|
)
|
||||||
except Exception as exc:
|
|
||||||
last_exc = exc
|
|
||||||
if attempt < _max_retries - 1:
|
|
||||||
await asyncio.sleep(2**attempt)
|
|
||||||
if last_exc:
|
|
||||||
raise last_exc
|
|
||||||
|
|
||||||
async def _upload_media_ws(
|
async def _upload_media_ws(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from concurrent.futures import Future
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
@@ -333,6 +335,71 @@ class TestChannelBase:
|
|||||||
|
|
||||||
_run(go())
|
_run(go())
|
||||||
|
|
||||||
|
def test_send_with_retry_retries_until_success(self, monkeypatch):
    """Two failures followed by a success yield the result after sleeps of 1 and 2."""
    channel = DummyChannel(MessageBus())
    fake_sleep = AsyncMock()
    # Replace the sleep used by the retry helper so the test does not wait.
    monkeypatch.setattr("app.channels.base.asyncio.sleep", fake_sleep)
    attempts = 0

    async def flaky_send():
        nonlocal attempts
        attempts += 1
        if attempts >= 3:
            return "sent"
        raise RuntimeError(f"failure {attempts}")

    outcome = _run(channel._send_with_retry(flaky_send, max_retries=3, log_prefix="[Dummy]"))

    assert outcome == "sent"
    assert attempts == 3
    observed_delays = [awaited.args[0] for awaited in fake_sleep.await_args_list]
    assert observed_delays == [1, 2]
|
||||||
|
|
||||||
|
def test_log_future_error_handles_cancelled_future(self, caplog):
    """A cancelled future must not produce an error log entry."""
    channel = DummyChannel(MessageBus())
    cancelled = Future()
    cancelled.cancel()

    with caplog.at_level(logging.ERROR):
        channel._log_future_error(cancelled, "prepare_inbound", "m1")

    captured = caplog.text
    assert "prepare_inbound" not in captured
|
||||||
|
|
||||||
|
def test_log_future_error_surfaces_future_exception(self, caplog):
    """A future that failed with an exception is reported in the error log."""
    channel = DummyChannel(MessageBus())
    failed = Future()
    failed.set_exception(RuntimeError("boom"))

    with caplog.at_level(logging.ERROR):
        channel._log_future_error(failed, "prepare_inbound", "m1")

    captured = caplog.text
    assert "prepare_inbound failed for msg_id=m1: boom" in captured
|
||||||
|
|
||||||
|
def test_channel_capabilities_match_channel_defaults(self):
    """CHANNEL_CAPABILITIES must agree with each channel's own supports_streaming."""
    from app.channels.dingtalk import DingTalkChannel
    from app.channels.discord import DiscordChannel
    from app.channels.feishu import FeishuChannel
    from app.channels.manager import CHANNEL_CAPABILITIES
    from app.channels.slack import SlackChannel
    from app.channels.telegram import TelegramChannel
    from app.channels.wechat import WechatChannel
    from app.channels.wecom import WeComChannel

    channel_classes = {
        "dingtalk": DingTalkChannel,
        "discord": DiscordChannel,
        "feishu": FeishuChannel,
        "slack": SlackChannel,
        "telegram": TelegramChannel,
        "wechat": WechatChannel,
        "wecom": WeComChannel,
    }

    shared_bus = MessageBus()
    # Instantiate every channel with an empty config and read its default.
    expected = {
        key: channel_cls(bus=shared_bus, config={}).supports_streaming
        for key, channel_cls in channel_classes.items()
    }
    advertised = {name: caps["supports_streaming"] for name, caps in CHANNEL_CAPABILITIES.items()}

    assert advertised == expected
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# _extract_response_text tests
|
# _extract_response_text tests
|
||||||
|
|||||||
Reference in New Issue
Block a user