diff --git a/backend/app/channels/wecom.py b/backend/app/channels/wecom.py index badb0b525..1a2757ada 100644 --- a/backend/app/channels/wecom.py +++ b/backend/app/channels/wecom.py @@ -82,12 +82,33 @@ class WeComChannel(Channel): self._ws_client.on("message.mixed", self._on_ws_mixed) self._ws_client.on("message.image", self._on_ws_image) self._ws_client.on("message.file", self._on_ws_file) + self._ws_client.on("error", self._on_ws_error) + self._ws_client.on("disconnected", self._on_ws_disconnected) self._ws_task = asyncio.create_task(self._ws_client.connect()) + self._ws_task.add_done_callback(self._on_ws_task_done) self._running = True self.bus.subscribe_outbound(self._on_outbound) logger.info("WeCom channel started") + def _on_ws_task_done(self, task: asyncio.Task) -> None: + if task.cancelled(): + return + exc = task.exception() + if exc is None: + return + logger.error( + "WeCom WebSocket connection task failed: %s. Check that the network/proxy allows wss://openws.work.weixin.qq.com and that bot_id/bot_secret are valid.", + exc, + ) + + def _on_ws_error(self, error: Any) -> None: + logger.error("WeCom WebSocket error: %s", error) + + def _on_ws_disconnected(self, *args: Any) -> None: + detail = f" ({args[0]})" if args else "" + logger.warning("WeCom WebSocket disconnected%s; SDK will attempt to reconnect", detail) + async def stop(self) -> None: self._running = False self.bus.unsubscribe_outbound(self._on_outbound) diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 0ab033ba8..b4eea74ea 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -3368,6 +3368,121 @@ class TestWeComChannel: _run(go()) + def test_on_ws_task_done_logs_error_on_exception(self, caplog): + import logging + + from app.channels.wecom import WeComChannel + + channel = WeComChannel(MessageBus(), config={}) + task = MagicMock() + task.cancelled.return_value = False + task.exception.return_value = RuntimeError("boom") + + with caplog.at_level(logging.ERROR): + channel._on_ws_task_done(task) + + assert any("WeCom WebSocket connection task failed" in r.message and r.levelno == logging.ERROR for r in caplog.records) + + def test_on_ws_task_done_silent_when_cancelled(self, caplog): + import logging + + from app.channels.wecom import WeComChannel + + channel = WeComChannel(MessageBus(), config={}) + task = MagicMock() + task.cancelled.return_value = True + + with caplog.at_level(logging.ERROR): + channel._on_ws_task_done(task) + + task.exception.assert_not_called() + assert caplog.records == [] + + def test_on_ws_task_done_silent_when_no_exception(self, caplog): + import logging + + from app.channels.wecom import WeComChannel + + channel = WeComChannel(MessageBus(), config={}) + task = MagicMock() + task.cancelled.return_value = False + task.exception.return_value = None + + with caplog.at_level(logging.ERROR): + channel._on_ws_task_done(task) + + assert caplog.records == [] + + def test_on_ws_error_logs_error(self, caplog): + import logging + + from app.channels.wecom import WeComChannel + + channel = WeComChannel(MessageBus(), config={}) + + with caplog.at_level(logging.ERROR): + channel._on_ws_error(RuntimeError("handshake failed")) + + assert any("WeCom WebSocket error" in r.message and r.levelno == logging.ERROR for r in caplog.records) + + def test_on_ws_disconnected_logs_warning(self, caplog): + import logging + + from app.channels.wecom import WeComChannel + + channel = WeComChannel(MessageBus(), config={}) + + with caplog.at_level(logging.WARNING): + channel._on_ws_disconnected() + + assert any("WeCom WebSocket disconnected" in r.message and r.levelno == logging.WARNING for r in caplog.records) + + def test_on_ws_disconnected_logs_reason_when_present(self, caplog): + import logging + + from app.channels.wecom import WeComChannel + + channel = WeComChannel(MessageBus(), config={}) + + with caplog.at_level(logging.WARNING): + channel._on_ws_disconnected("connection reset") + + assert any("connection reset" in r.message and r.levelno == logging.WARNING for r in caplog.records) + + def test_start_subscribes_connection_lifecycle_events(self, monkeypatch): + from app.channels.wecom import WeComChannel + + async def go(): + bus = MessageBus() + channel = WeComChannel(bus, config={"bot_id": "corp123", "bot_secret": "secret"}) + + ws_client = MagicMock() + + async def fake_connect(): + return None + + ws_client.connect = fake_connect + + monkeypatch.setitem( + __import__("sys").modules, + "aibot", + SimpleNamespace( + WSClient=lambda options: ws_client, + WSClientOptions=lambda **kwargs: SimpleNamespace(**kwargs), + ), + ) + + await channel.start() + + subscribed_events = {call.args[0] for call in ws_client.on.call_args_list} + assert "error" in subscribed_events + assert "disconnected" in subscribed_events + assert channel._ws_task is not None + + await channel.stop() + + _run(go()) + class TestChannelService: def test_get_status_no_channels(self):