mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 01:45:58 +00:00
Address Copilot IM channel feedback
This commit is contained in:
@@ -99,6 +99,10 @@ class SlackChannel(Channel):
|
|||||||
app_token = self.config.get("app_token", "")
|
app_token = self.config.get("app_token", "")
|
||||||
|
|
||||||
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
|
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
|
||||||
|
if not bot_token:
|
||||||
|
logger.error("Slack HTTP Events mode requires bot_token")
|
||||||
|
return
|
||||||
|
await self._initialize_operator_web_client(str(bot_token))
|
||||||
self._loop = asyncio.get_event_loop()
|
self._loop = asyncio.get_event_loop()
|
||||||
self._running = True
|
self._running = True
|
||||||
self.bus.subscribe_outbound(self._on_outbound)
|
self.bus.subscribe_outbound(self._on_outbound)
|
||||||
@@ -109,18 +113,7 @@ class SlackChannel(Channel):
|
|||||||
logger.error("Slack channel requires bot_token and app_token")
|
logger.error("Slack channel requires bot_token and app_token")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._web_client = self._web_client_factory(token=bot_token)
|
await self._initialize_operator_web_client(str(bot_token))
|
||||||
if self._bot_user_id is None:
|
|
||||||
try:
|
|
||||||
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
|
||||||
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
|
||||||
if user_id is None:
|
|
||||||
auth_get = getattr(auth_info, "get", None)
|
|
||||||
user_id = auth_get("user_id") if callable(auth_get) else None
|
|
||||||
if isinstance(user_id, str) and user_id:
|
|
||||||
self._bot_user_id = user_id
|
|
||||||
except Exception:
|
|
||||||
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
|
||||||
self._socket_client = SocketModeClient(
|
self._socket_client = SocketModeClient(
|
||||||
app_token=app_token,
|
app_token=app_token,
|
||||||
web_client=self._web_client,
|
web_client=self._web_client,
|
||||||
@@ -224,6 +217,21 @@ class SlackChannel(Channel):
|
|||||||
|
|
||||||
# -- internal ----------------------------------------------------------
|
# -- internal ----------------------------------------------------------
|
||||||
|
|
||||||
|
async def _initialize_operator_web_client(self, bot_token: str) -> None:
|
||||||
|
self._web_client = self._web_client_factory(token=bot_token)
|
||||||
|
if self._bot_user_id is not None:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
auth_info = await asyncio.to_thread(self._web_client.auth_test)
|
||||||
|
user_id = auth_info.get("user_id") if isinstance(auth_info, dict) else None
|
||||||
|
if user_id is None:
|
||||||
|
auth_get = getattr(auth_info, "get", None)
|
||||||
|
user_id = auth_get("user_id") if callable(auth_get) else None
|
||||||
|
if isinstance(user_id, str) and user_id:
|
||||||
|
self._bot_user_id = user_id
|
||||||
|
except Exception:
|
||||||
|
logger.warning("[Slack] failed to resolve bot user id; app mention text may include the bot mention", exc_info=True)
|
||||||
|
|
||||||
async def _get_web_client_for_message(self, msg: OutboundMessage):
|
async def _get_web_client_for_message(self, msg: OutboundMessage):
|
||||||
if msg.connection_id and self._connection_repo is not None:
|
if msg.connection_id and self._connection_repo is not None:
|
||||||
credentials = await self._connection_repo.get_credentials(msg.connection_id)
|
credentials = await self._connection_repo.get_credentials(msg.connection_id)
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ def _provider_status(
|
|||||||
|
|
||||||
|
|
||||||
def _new_binding_code() -> str:
|
def _new_binding_code() -> str:
|
||||||
return secrets.token_hex(4)
|
return secrets.token_urlsafe(16)
|
||||||
|
|
||||||
|
|
||||||
async def _create_state(
|
async def _create_state(
|
||||||
@@ -216,7 +216,9 @@ async def get_channel_providers(request: Request) -> ChannelProvidersResponse:
|
|||||||
raise
|
raise
|
||||||
owner_user_id = _get_user_id(request)
|
owner_user_id = _get_user_id(request)
|
||||||
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
connections = await repo.list_connections(owner_user_id) if repo is not None else []
|
||||||
by_provider = {item["provider"]: item for item in connections}
|
by_provider: dict[str, dict[str, Any]] = {}
|
||||||
|
for item in connections:
|
||||||
|
by_provider.setdefault(item["provider"], item)
|
||||||
|
|
||||||
providers: list[ChannelProviderResponse] = []
|
providers: list[ChannelProviderResponse] = []
|
||||||
for provider, meta in _PROVIDER_META.items():
|
for provider, meta in _PROVIDER_META.items():
|
||||||
|
|||||||
@@ -101,6 +101,41 @@ def test_get_providers_reports_unconfigured_when_runtime_channel_is_missing(tmp_
|
|||||||
anyio.run(repo.close)
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_providers_uses_newest_connection_status_per_provider(tmp_path):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
repo = anyio.run(_make_repo, tmp_path)
|
||||||
|
|
||||||
|
async def seed_connections():
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id=str(_user().id),
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-old",
|
||||||
|
workspace_id="T-old",
|
||||||
|
status="revoked",
|
||||||
|
)
|
||||||
|
await anyio.sleep(0.01)
|
||||||
|
await repo.upsert_connection(
|
||||||
|
owner_user_id=str(_user().id),
|
||||||
|
provider="slack",
|
||||||
|
external_account_id="U-new",
|
||||||
|
workspace_id="T-new",
|
||||||
|
status="connected",
|
||||||
|
)
|
||||||
|
|
||||||
|
anyio.run(seed_connections)
|
||||||
|
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||||
|
|
||||||
|
with TestClient(app) as client:
|
||||||
|
response = client.get("/api/channels/providers")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||||
|
assert by_provider["slack"]["connection_status"] == "connected"
|
||||||
|
|
||||||
|
anyio.run(repo.close)
|
||||||
|
|
||||||
|
|
||||||
def test_get_connections_returns_current_user_connections_only(tmp_path):
|
def test_get_connections_returns_current_user_connections_only(tmp_path):
|
||||||
import anyio
|
import anyio
|
||||||
|
|
||||||
@@ -176,7 +211,7 @@ def test_connect_slack_returns_binding_command_and_persists_state(tmp_path):
|
|||||||
assert body["provider"] == "slack"
|
assert body["provider"] == "slack"
|
||||||
assert body["mode"] == "binding_code"
|
assert body["mode"] == "binding_code"
|
||||||
assert body["url"] is None
|
assert body["url"] is None
|
||||||
assert body["code"]
|
assert len(body["code"]) >= 22
|
||||||
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Slack bot."
|
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Slack bot."
|
||||||
|
|
||||||
async def count_states():
|
async def count_states():
|
||||||
|
|||||||
@@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sys
|
||||||
from datetime import UTC, datetime, timedelta
|
from datetime import UTC, datetime, timedelta
|
||||||
|
from types import ModuleType
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
from app.channels.message_bus import MessageBus, OutboundMessage
|
from app.channels.message_bus import MessageBus, OutboundMessage
|
||||||
@@ -94,3 +96,59 @@ def test_slack_send_uses_connection_bot_token_when_connection_id_is_present():
|
|||||||
web_client.chat_postMessage.assert_called_once()
|
web_client.chat_postMessage.assert_called_once()
|
||||||
|
|
||||||
anyio.run(go)
|
anyio.run(go)
|
||||||
|
|
||||||
|
|
||||||
|
def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch):
|
||||||
|
import anyio
|
||||||
|
|
||||||
|
from app.channels.slack import SlackChannel
|
||||||
|
|
||||||
|
class FakeWebClient:
|
||||||
|
def __init__(self, token: str) -> None:
|
||||||
|
self.token = token
|
||||||
|
self.messages: list[dict] = []
|
||||||
|
|
||||||
|
def auth_test(self):
|
||||||
|
return {"user_id": "B-http"}
|
||||||
|
|
||||||
|
def chat_postMessage(self, **kwargs):
|
||||||
|
self.messages.append(kwargs)
|
||||||
|
|
||||||
|
slack_sdk = ModuleType("slack_sdk")
|
||||||
|
slack_sdk.WebClient = FakeWebClient
|
||||||
|
socket_mode = ModuleType("slack_sdk.socket_mode")
|
||||||
|
socket_mode.SocketModeClient = object
|
||||||
|
response = ModuleType("slack_sdk.socket_mode.response")
|
||||||
|
response.SocketModeResponse = object
|
||||||
|
monkeypatch.setitem(sys.modules, "slack_sdk", slack_sdk)
|
||||||
|
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode", socket_mode)
|
||||||
|
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode.response", response)
|
||||||
|
|
||||||
|
async def go():
|
||||||
|
channel = SlackChannel(
|
||||||
|
bus=MessageBus(),
|
||||||
|
config={
|
||||||
|
"bot_token": "xoxb-operator",
|
||||||
|
"event_delivery": "http",
|
||||||
|
"connection_repo": MagicMock(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await channel.start()
|
||||||
|
assert channel._running is True
|
||||||
|
assert channel._web_client is not None
|
||||||
|
assert channel._web_client.token == "xoxb-operator"
|
||||||
|
assert channel._bot_user_id == "B-http"
|
||||||
|
|
||||||
|
channel._post_connection_reply("C123", "Slack connected to DeerFlow.", "1710000000.000100")
|
||||||
|
|
||||||
|
assert channel._web_client.messages == [
|
||||||
|
{
|
||||||
|
"channel": "C123",
|
||||||
|
"text": "Slack connected to DeerFlow.",
|
||||||
|
"thread_ts": "1710000000.000100",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
await channel.stop()
|
||||||
|
|
||||||
|
anyio.run(go)
|
||||||
|
|||||||
Reference in New Issue
Block a user