From 8c0830aea1d0aecd33cfd69d0a100a53ce5eece0 Mon Sep 17 00:00:00 2001 From: Nan Gao Date: Thu, 18 Jun 2026 04:09:46 +0200 Subject: [PATCH] fix(channels): add operational guardrails (#3584) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(channels): add operational guardrails * make format * fix(channels): converge with #3582 to avoid merge-order conflicts Drop this PR's DingTalk INFO-log redaction and hand it to #3582, which already restructures that handler and will redact the same log there. This PR no longer touches dingtalk.py, so the two PRs can merge to main in any order without a conflict. For WeChat, drop the contested thread_ts priority reorder (review #3) and keep only what inbound dedupe needs: a server-stable message_id in the inbound metadata (message_id/msg_id, no client_id per review #6). This is a single added line inside the metadata dict, a region #3582 never touches, so it auto-merges regardless of order. Co-Authored-By: Claude Opus 4.8 * fix(channels): address three correctness review findings 1. Connect-code cap was racy (willem #1): _create_state ran delete-expired, count, and insert as three separate transactions, so concurrent connect POSTs from one owner could each see count < cap and all insert past it. Add ChannelConnectionRepository.create_oauth_state_within_cap which does delete+count+insert in a single transaction serialized per (owner, provider) — Postgres via pg_advisory_xact_lock, SQLite via the write lock the leading DELETE takes — and have the router use it. 2. Inbound dedupe key fell back to "" workspace (willem #3): two workspaces delivering without team/guild/aibotid would collapse to the same key and dedupe each other's messages. _inbound_dedupe_key now fails closed (returns None) when no workspace identifier is present. 3. Dedupe key was recorded on receipt and never released on failure (ShenAC #1): a transient error (DB blip, Gateway 503) left the key in place for the full TTL, so a provider redelivery of the same message_id — exactly the retry dedupe should absorb — was silently dropped. _handle_message now releases the key in the unexpected-exception branch so redelivery can recover, while keeping record-on-receipt so retries during handling are still deduped. Tests: repo cap enforcement incl. concurrent-issuance non-leak; dedupe fail-closed; dedupe key release-on-failure redelivery recovery. Co-Authored-By: Claude Opus 4.8 * fix(channels): address cleanup/efficiency and test review findings Efficiency / cleanup: - Dedupe key set drops client-generated ids (client_msg_id, client_id); keep only server-stable event_id/message_id/msg_id, which a provider's own redelivery preserves (ShenAC #6). Every provider already emits message_id. - TTL/overflow pruning of _recent_inbound_events is now O(k): switch to an OrderedDict and popitem(last=False) from the front instead of scanning all 4096 entries on every inbound (willem #4). - Log "received inbound" only after the dedupe check so a provider retrying N times no longer logs N accepts; document that manager dedupe covers the agent run/final answer, not provider ack side-effects (willem #5, ShenAC #2). - Slack drops the redundant `team_id or event.get("team")` fallback the caller already resolved (willem #6). - create_oauth_state_within_cap prunes only this owner/provider's expired codes instead of a global DELETE on every connect POST; global cleanup still runs on consume_oauth_state (willem #7). Tests: - Dedupe test uses tmp_path instead of a leaked mkdtemp, uses distinct objects per publish, and adds a negative control: a different message_id is still processed, catching over-dedupe regressions (willem #8, ShenAC #4). - Slack HTTP-mode rejection test supplies app_token so the missing-token early return can't mask the guard, giving the state assertions teeth (ShenAC #3). - count_oauth_states test pins that the active row survives, not just the count (ShenAC #5). Co-Authored-By: Claude Opus 4.8 * make format --------- Co-authored-by: Claude Opus 4.8 --- backend/app/channels/feishu.py | 4 +- backend/app/channels/manager.py | 101 ++++++++++++++- backend/app/channels/slack.py | 19 ++- backend/app/channels/telegram.py | 2 + backend/app/channels/wechat.py | 1 + backend/app/channels/wecom.py | 6 +- .../gateway/routers/channel_connections.py | 15 ++- .../persistence/channel_connections/sql.py | 118 ++++++++++++++++- .../test_channel_connections_repository.py | 71 ++++++++++ .../tests/test_channel_connections_router.py | 21 +++ backend/tests/test_channels.py | 122 +++++++++++++++++- .../tests/test_slack_channel_connections.py | 39 ++---- 12 files changed, 468 insertions(+), 51 deletions(-) diff --git a/backend/app/channels/feishu.py b/backend/app/channels/feishu.py index b6b34424d..42da2dd27 100644 --- a/backend/app/channels/feishu.py +++ b/backend/app/channels/feishu.py @@ -837,14 +837,14 @@ class FeishuChannel(Channel): text = text.strip() logger.info( - "[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, parent_id=%s, thread_id=%s, sender=%s, text=%r", + "[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, parent_id=%s, thread_id=%s, sender=%s, text_len=%d", chat_id, msg_id, root_id, parent_id, feishu_thread_id, sender_id, - text[:100] if text else "", + len(text or ""), ) if not (text or files_list): diff --git a/backend/app/channels/manager.py b/backend/app/channels/manager.py index 5d1705524..df8ca2b28 100644 --- a/backend/app/channels/manager.py +++ b/backend/app/channels/manager.py @@ -7,6 +7,7 @@ import logging import mimetypes import re import time +from collections import OrderedDict from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass from pathlib import Path @@ -62,6 +63,12 @@ MESSAGE_STREAM_EVENTS = ("messages-tuple", "messages") THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again." BOUND_IDENTITY_REQUIRED_MESSAGE = "Connect this channel from DeerFlow Settings, complete the in-channel connect step, then send your message again." BOUND_IDENTITY_UNAVAILABLE_MESSAGE = "Channel connection verification is temporarily unavailable. Please try again later or contact the DeerFlow operator." +INBOUND_DEDUPE_TTL_SECONDS = 10 * 60 +INBOUND_DEDUPE_MAX_ENTRIES = 4096 +# Only server-stable provider message ids: client-generated ids (client_msg_id, +# client_id) are not guaranteed identical across a provider's own redelivery, so +# keying dedupe on them would miss exactly the retries we want to absorb. +INBOUND_DEDUPE_METADATA_KEYS = ("event_id", "message_id", "msg_id") CHANNEL_CAPABILITIES = { "dingtalk": {"supports_streaming": False}, @@ -774,6 +781,10 @@ class ChannelManager: self._semaphore: asyncio.Semaphore | None = None self._running = False self._task: asyncio.Task | None = None + # Insertion order == chronological (keys are never re-inserted), so an + # OrderedDict lets us evict expired/overflow entries from the front in + # O(k) instead of scanning all entries on every inbound message. + self._recent_inbound_events: OrderedDict[tuple[str, str, str, str], float] = OrderedDict() @staticmethod def _channel_supports_streaming(channel_name: str) -> bool: @@ -919,16 +930,94 @@ class ChannelManager: except asyncio.CancelledError: break + # Dedupe before logging "received" so a provider retrying an event N + # times does not log N accepts; duplicates are logged once as ignored. + # Note: this manager-level dedupe only guards the agent run / final + # answer. Provider adapters may emit ack side-effects (a "Working on + # it…" reply, an "eyes" reaction) before publish_inbound, so those are + # intentionally not deduped here. + if self._is_duplicate_inbound(msg): + continue logger.info( - "[Manager] received inbound: channel=%s, chat_id=%s, type=%s, text=%r", + "[Manager] received inbound: channel=%s, chat_id=%s, type=%s, text_len=%d, files=%d", msg.channel_name, msg.chat_id, msg.msg_type.value, - msg.text[:100] if msg.text else "", + len(msg.text or ""), + len(msg.files), ) task = asyncio.create_task(self._handle_message(msg)) task.add_done_callback(self._log_task_error) + @staticmethod + def _inbound_dedupe_key(msg: InboundMessage) -> tuple[str, str, str, str] | None: + metadata = msg.metadata or {} + message_id = None + for key in INBOUND_DEDUPE_METADATA_KEYS: + value = metadata.get(key) + if value: + message_id = str(value) + break + if message_id is None: + raw_message = metadata.get("raw_message") + if isinstance(raw_message, Mapping): + for key in INBOUND_DEDUPE_METADATA_KEYS: + value = raw_message.get(key) + if value: + message_id = str(value) + break + if message_id is None: + return None + + # Fail closed: without a workspace/team/guild identifier we cannot tell two + # workspaces apart (e.g. Slack channel ids are not globally unique), so + # skip dedupe rather than risk collapsing distinct workspaces' messages. + workspace_id = msg.workspace_id or metadata.get("workspace_id") or metadata.get("team_id") or metadata.get("guild_id") or metadata.get("aibotid") + if not workspace_id: + return None + return (msg.channel_name, str(workspace_id), msg.chat_id, message_id) + + def _is_duplicate_inbound(self, msg: InboundMessage) -> bool: + key = self._inbound_dedupe_key(msg) + if key is None: + return False + + now = time.monotonic() + # Entries are in chronological insertion order, so expired ones cluster at + # the front: pop from the front until we hit a still-live entry. + while self._recent_inbound_events: + _, oldest_at = next(iter(self._recent_inbound_events.items())) + if now - oldest_at > INBOUND_DEDUPE_TTL_SECONDS: + self._recent_inbound_events.popitem(last=False) + else: + break + while len(self._recent_inbound_events) > INBOUND_DEDUPE_MAX_ENTRIES: + self._recent_inbound_events.popitem(last=False) + + if key in self._recent_inbound_events: + logger.info( + "[Manager] duplicate inbound ignored: channel=%s, chat_id=%s, message_id=%s", + msg.channel_name, + msg.chat_id, + key[-1], + ) + return True + + self._recent_inbound_events[key] = now + return False + + def _release_inbound_dedupe_key(self, msg: InboundMessage) -> None: + """Drop a recorded dedupe key so a provider redelivery can be reprocessed. + + Called only on transient/unexpected handling failures: the key was + recorded on receipt so retries arriving *while* the message is being + handled are still deduped, but if handling fails we must not turn a + recoverable error into a TTL-long black hole for the same message_id. + """ + key = self._inbound_dedupe_key(msg) + if key is not None: + self._recent_inbound_events.pop(key, None) + @staticmethod def _log_task_error(task: asyncio.Task) -> None: """Surface unhandled exceptions from background tasks.""" @@ -979,6 +1068,10 @@ class ChannelManager: msg.channel_name, msg.chat_id, ) + # Transient/unexpected failure: release the dedupe key so a provider + # redelivery of the same message can recover instead of being dropped + # for the dedupe TTL. + self._release_inbound_dedupe_key(msg) await self._send_error(msg, "An internal error occurred. Please try again.") # -- chat handling ----------------------------------------------------- @@ -1169,7 +1262,7 @@ class ChannelManager: ) return - logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100]) + logger.info("[Manager] invoking runs.wait(thread_id=%s, text_len=%d)", thread_id, len(msg.text or "")) run_kwargs: dict[str, Any] = { "input": {"messages": [human_message]}, "config": run_config, @@ -1236,7 +1329,7 @@ class ChannelManager: run_context: dict[str, Any], human_message: dict[str, Any], ) -> None: - logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100]) + logger.info("[Manager] invoking runs.stream(thread_id=%s, text_len=%d)", thread_id, len(msg.text or "")) last_values: dict[str, Any] | list | None = None streamed_buffers: dict[str, str] = {} diff --git a/backend/app/channels/slack.py b/backend/app/channels/slack.py index 96141f3ca..34266515e 100644 --- a/backend/app/channels/slack.py +++ b/backend/app/channels/slack.py @@ -90,15 +90,8 @@ class SlackChannel(Channel): bot_token = self.config.get("bot_token", "") app_token = self.config.get("app_token", "") - if self._connection_repo is not None and self.config.get("event_delivery") == "http": - if not bot_token: - logger.error("Slack HTTP Events mode requires bot_token") - return - await self._initialize_operator_web_client(str(bot_token)) - self._loop = asyncio.get_event_loop() - self._running = True - self.bus.subscribe_outbound(self._on_outbound) - logger.info("Slack channel started in HTTP Events mode") + if self.config.get("event_delivery") == "http": + logger.error("Slack HTTP Events mode is not supported by this channel adapter; use Socket Mode with app_token") return if not bot_token or not app_token: @@ -319,7 +312,7 @@ class SlackChannel(Channel): asyncio.run_coroutine_threadsafe( self._bind_connection_from_connect_code( event=event, - team_id=str(team_id or event.get("team") or ""), + team_id=str(team_id or ""), code=connect_code, ), self._loop, @@ -343,6 +336,12 @@ class SlackChannel(Channel): text=text, msg_type=msg_type, thread_ts=thread_ts, + metadata={ + # team_id is already resolved (payload team_id/team, else event team) by the caller. + "team_id": team_id, + "message_id": event.get("ts"), + "client_msg_id": event.get("client_msg_id"), + }, ) inbound.topic_id = thread_ts diff --git a/backend/app/channels/telegram.py b/backend/app/channels/telegram.py index 0f92f0461..0061c0d1f 100644 --- a/backend/app/channels/telegram.py +++ b/backend/app/channels/telegram.py @@ -503,6 +503,7 @@ class TelegramChannel(Channel): text=text, msg_type=InboundMessageType.COMMAND, thread_ts=msg_id, + metadata={"message_id": msg_id}, ) inbound.topic_id = topic_id inbound = await self._attach_connection_identity(inbound) @@ -546,6 +547,7 @@ class TelegramChannel(Channel): text=text, msg_type=InboundMessageType.CHAT, thread_ts=msg_id, + metadata={"message_id": msg_id}, ) inbound.topic_id = topic_id inbound = await self._attach_connection_identity(inbound) diff --git a/backend/app/channels/wechat.py b/backend/app/channels/wechat.py index f2db2c380..0f8a61122 100644 --- a/backend/app/channels/wechat.py +++ b/backend/app/channels/wechat.py @@ -627,6 +627,7 @@ class WechatChannel(Channel): metadata={ "context_token": context_token, "ilink_user_id": chat_id, + "message_id": str(raw_message.get("message_id") or raw_message.get("msg_id") or "").strip(), "ref_msg": self._extract_ref_message(raw_message), "raw_message": raw_message, }, diff --git a/backend/app/channels/wecom.py b/backend/app/channels/wecom.py index 19997ed54..5b287f536 100644 --- a/backend/app/channels/wecom.py +++ b/backend/app/channels/wecom.py @@ -313,7 +313,11 @@ class WeComChannel(Channel): msg_type=inbound_type, thread_ts=msg_id, files=files or [], - metadata={"aibotid": body.get("aibotid"), "chattype": body.get("chattype")}, + metadata={ + "aibotid": body.get("aibotid"), + "chattype": body.get("chattype"), + "message_id": msg_id, + }, ) inbound.topic_id = user_id # keep the same thread diff --git a/backend/app/gateway/routers/channel_connections.py b/backend/app/gateway/routers/channel_connections.py index ea57f5126..2b9c9a6b2 100644 --- a/backend/app/gateway/routers/channel_connections.py +++ b/backend/app/gateway/routers/channel_connections.py @@ -24,6 +24,7 @@ router = APIRouter(prefix="/api/channels", tags=["channel-connections"]) logger = logging.getLogger(__name__) _STATE_TTL_SECONDS = 600 +_MAX_PENDING_CONNECT_CODES_PER_PROVIDER = 5 _MASKED_CREDENTIAL_VALUE = "********" @@ -332,13 +333,23 @@ async def _create_state( owner_user_id: str, provider: str, ) -> str: + now = datetime.now(UTC) state = _new_binding_code() - await repo.create_oauth_state( + # Atomic delete-expired + count + insert so concurrent connect POSTs from one + # owner cannot each see count < cap and all insert past the cap. + inserted = await repo.create_oauth_state_within_cap( owner_user_id=owner_user_id, provider=provider, state=state, - expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS), + expires_at=now + timedelta(seconds=_STATE_TTL_SECONDS), + max_pending=_MAX_PENDING_CONNECT_CODES_PER_PROVIDER, + now=now, ) + if not inserted: + raise HTTPException( + status_code=429, + detail="Too many pending channel connection codes. Wait for existing codes to expire or use one of them.", + ) return state diff --git a/backend/packages/harness/deerflow/persistence/channel_connections/sql.py b/backend/packages/harness/deerflow/persistence/channel_connections/sql.py index bc71926fa..e48a48ff7 100644 --- a/backend/packages/harness/deerflow/persistence/channel_connections/sql.py +++ b/backend/packages/harness/deerflow/persistence/channel_connections/sql.py @@ -11,7 +11,7 @@ from datetime import UTC, datetime from typing import Any from cryptography.fernet import Fernet, InvalidToken -from sqlalchemy import delete, func, select, update +from sqlalchemy import delete, func, select, text, update from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker @@ -279,16 +279,128 @@ class ChannelConnectionRepository: session.add(row) await session.commit() - async def count_oauth_states(self, *, owner_user_id: str, provider: str) -> int: + async def create_oauth_state_within_cap( + self, + *, + owner_user_id: str, + provider: str, + state: str, + expires_at: datetime, + max_pending: int, + now: datetime | None = None, + code_verifier: str | None = None, + nonce_hash: str | None = None, + redirect_after: str | None = None, + requested_scopes: list[str] | None = None, + metadata: dict[str, Any] | None = None, + ) -> bool: + """Atomically enforce the per-(owner, provider) pending cap, then insert. + + delete-expired + count + insert run in a single transaction serialized + per (owner, provider), so concurrent connect requests cannot each + observe ``count < max_pending`` and all insert (which would leak past + the cap). PostgreSQL takes a transaction-scoped advisory lock; SQLite + serializes writers through the write lock the leading DELETE acquires. + + Returns ``True`` when the row was inserted, ``False`` when the cap is + already reached. + """ + current_time = now or datetime.now(UTC) async with self.session_factory() as session: - result = await session.execute( + await self._serialize_oauth_owner_scope(session, owner_user_id, provider) + # Prune only this owner/provider's expired codes (the ones that affect + # this cap), not every user's — avoids a global DELETE on each connect + # POST. Issuing this write first also takes the SQLite database write + # lock so the count below cannot race a concurrent inserter between + # count and commit. Stale codes for other owners are pruned globally + # by consume_oauth_state / delete_expired_oauth_states. + await session.execute( + delete(ChannelOAuthStateRow).where( + ChannelOAuthStateRow.owner_user_id == owner_user_id, + ChannelOAuthStateRow.provider == provider, + ChannelOAuthStateRow.expires_at < current_time, + ) + ) + pending = await session.execute( select(func.count()) .select_from(ChannelOAuthStateRow) .where( ChannelOAuthStateRow.owner_user_id == owner_user_id, ChannelOAuthStateRow.provider == provider, + ChannelOAuthStateRow.consumed_at.is_(None), + ChannelOAuthStateRow.expires_at >= current_time, ) ) + if int(pending.scalar_one()) >= max_pending: + await session.rollback() + return False + session.add( + ChannelOAuthStateRow( + state_hash=self.hash_state(state), + owner_user_id=owner_user_id, + provider=provider, + code_verifier_encrypted=self._encrypt_optional_secret(code_verifier), + nonce_hash=nonce_hash, + redirect_after=redirect_after, + requested_scopes_json=list(requested_scopes or []), + metadata_json=dict(metadata or {}), + expires_at=expires_at, + ) + ) + await session.commit() + return True + + async def _serialize_oauth_owner_scope(self, session: AsyncSession, owner_user_id: str, provider: str) -> None: + """Serialize concurrent pending-cap transactions for one (owner, provider). + + On PostgreSQL this takes a transaction-scoped advisory lock so concurrent + issuers run their count+insert one at a time. On SQLite the leading + DELETE in the caller's transaction already acquires the database write + lock, which serializes writers, so no extra lock is required. + """ + try: + dialect = session.bind.dialect.name if session.bind is not None else "" + except Exception: + dialect = "" + if dialect == "postgresql": + await session.execute(text("SELECT pg_advisory_xact_lock(:lock_key)"), {"lock_key": self._oauth_scope_lock_key(owner_user_id, provider)}) + + @staticmethod + def _oauth_scope_lock_key(owner_user_id: str, provider: str) -> int: + digest = hashlib.sha256(f"{owner_user_id}\x00{provider}".encode()).digest() + # 63-bit non-negative key for pg_advisory_xact_lock(bigint). + return int.from_bytes(digest[:8], "big") & 0x7FFFFFFFFFFFFFFF + + async def delete_expired_oauth_states(self, *, now: datetime | None = None) -> int: + current_time = now or datetime.now(UTC) + async with self.session_factory() as session: + result = await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time)) + await session.commit() + return int(result.rowcount or 0) + + async def count_oauth_states( + self, + *, + owner_user_id: str, + provider: str, + active_only: bool = False, + now: datetime | None = None, + ) -> int: + current_time = now or datetime.now(UTC) + conditions = [ + ChannelOAuthStateRow.owner_user_id == owner_user_id, + ChannelOAuthStateRow.provider == provider, + ] + if active_only: + conditions.extend( + [ + ChannelOAuthStateRow.consumed_at.is_(None), + ChannelOAuthStateRow.expires_at >= current_time, + ] + ) + + async with self.session_factory() as session: + result = await session.execute(select(func.count()).select_from(ChannelOAuthStateRow).where(*conditions)) return int(result.scalar_one()) async def consume_oauth_state( diff --git a/backend/tests/test_channel_connections_repository.py b/backend/tests/test_channel_connections_repository.py index ae5610f89..ab873ee3c 100644 --- a/backend/tests/test_channel_connections_repository.py +++ b/backend/tests/test_channel_connections_repository.py @@ -246,6 +246,77 @@ class TestChannelConnectionRepository: states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all() assert [state.state_hash for state in states] == [repo.hash_state("active-state")] + @pytest.mark.anyio + async def test_count_oauth_states_active_only_and_delete_expired(self, repo): + now = datetime.now(UTC) + await repo.create_oauth_state( + owner_user_id="alice", + provider="slack", + state="expired-state", + expires_at=now - timedelta(minutes=1), + ) + await repo.create_oauth_state( + owner_user_id="alice", + provider="slack", + state="active-state", + expires_at=now + timedelta(minutes=5), + ) + + assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 1 + assert await repo.delete_expired_oauth_states(now=now) == 1 + assert await repo.count_oauth_states(owner_user_id="alice", provider="slack") == 1 + # Pin that the surviving row is the active one (an inverted expiry + # predicate would delete the active row, still return 1, and pass above). + async with repo.session_factory() as session: + survivors = (await session.execute(select(ChannelOAuthStateRow))).scalars().all() + assert [row.state_hash for row in survivors] == [repo.hash_state("active-state")] + + @pytest.mark.anyio + async def test_create_oauth_state_within_cap_enforces_pending_cap(self, repo): + now = datetime.now(UTC) + expires = now + timedelta(minutes=5) + + for i in range(3): + inserted = await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state=f"code-{i}", expires_at=expires, max_pending=3, now=now) + assert inserted is True + + # Cap reached: the next issuance is rejected and nothing is inserted. + assert await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state="code-over", expires_at=expires, max_pending=3, now=now) is False + assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 3 + + # Expired rows are pruned and free up capacity; a different owner is unaffected. + assert await repo.create_oauth_state_within_cap(owner_user_id="bob", provider="slack", state="bob-1", expires_at=expires, max_pending=3, now=now) is True + + @pytest.mark.anyio + async def test_create_oauth_state_within_cap_ignores_expired_rows(self, repo): + now = datetime.now(UTC) + # Three already-expired rows must not count against the cap. + for i in range(3): + await repo.create_oauth_state(owner_user_id="alice", provider="slack", state=f"old-{i}", expires_at=now - timedelta(minutes=1)) + + inserted = await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state="fresh", expires_at=now + timedelta(minutes=5), max_pending=3, now=now) + assert inserted is True + assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 1 + + @pytest.mark.anyio + async def test_create_oauth_state_within_cap_does_not_leak_under_concurrency(self, repo): + """Concurrent issuance for one owner cannot push past the cap (willem #1).""" + import anyio + + now = datetime.now(UTC) + expires = now + timedelta(minutes=5) + results: list[bool] = [] + + async def issue(state: str) -> None: + results.append(await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state=state, expires_at=expires, max_pending=3, now=now)) + + async with anyio.create_task_group() as tg: + for i in range(8): + tg.start_soon(issue, f"code-{i}") + + assert sum(1 for ok in results if ok) == 3 + assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 3 + @pytest.mark.anyio async def test_consume_oauth_state_is_one_time_even_under_concurrent_consumers(self, repo): import anyio diff --git a/backend/tests/test_channel_connections_router.py b/backend/tests/test_channel_connections_router.py index 638fbf215..046568f10 100644 --- a/backend/tests/test_channel_connections_router.py +++ b/backend/tests/test_channel_connections_router.py @@ -504,6 +504,27 @@ def test_connect_slack_returns_binding_command_and_persists_state(tmp_path): anyio.run(repo.close) +def test_connect_binding_code_caps_pending_states_per_provider(tmp_path): + import anyio + + repo = anyio.run(_make_repo, tmp_path) + app = _make_app(_enabled_connections_config(), repo, _channels_config()) + + with TestClient(app) as client: + responses = [client.post("/api/channels/slack/connect") for _ in range(6)] + + assert [response.status_code for response in responses[:5]] == [200, 200, 200, 200, 200] + assert responses[5].status_code == 429 + assert "Too many pending channel connection codes" in responses[5].json()["detail"] + + async def count_states(): + return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="slack") + + assert anyio.run(count_states) == 5 + + anyio.run(repo.close) + + def test_connect_discord_returns_binding_command_and_persists_state(tmp_path): import anyio diff --git a/backend/tests/test_channels.py b/backend/tests/test_channels.py index 5c99f7c9f..5db92c936 100644 --- a/backend/tests/test_channels.py +++ b/backend/tests/test_channels.py @@ -800,6 +800,126 @@ class TestChannelManager: _run(go()) + def test_dispatch_loop_dedupes_stable_provider_message_id(self, tmp_path): + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=tmp_path / "store.json") + manager = ChannelManager(bus=bus, store=store) + manager._client = _make_mock_langgraph_client() + outbound_received: list[OutboundMessage] = [] + + async def capture_outbound(msg: OutboundMessage) -> None: + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + def _slack_inbound(message_id: str) -> InboundMessage: + # Distinct objects per publish, like a real provider redelivery. + return InboundMessage( + channel_name="slack", + chat_id="C123", + user_id="U123", + text="sensitive prompt", + topic_id="1710000000.000100", + metadata={"team_id": "T123", "message_id": message_id}, + ) + + # Same stable message_id delivered twice -> processed once. + await bus.publish_inbound(_slack_inbound("1710000000.000200")) + await bus.publish_inbound(_slack_inbound("1710000000.000200")) + await _wait_for(lambda: manager._client.runs.wait.call_count == 1 and len(outbound_received) == 1) + await asyncio.sleep(0.05) + assert manager._client.threads.create.call_count == 1 + assert manager._client.runs.wait.call_count == 1 + assert len(outbound_received) == 1 + + # Negative control: a *different* message_id must still be processed, + # so an over-dedupe regression (dropping distinct messages) is caught. + await bus.publish_inbound(_slack_inbound("1710000000.000999")) + await _wait_for(lambda: manager._client.runs.wait.call_count == 2 and len(outbound_received) == 2) + await asyncio.sleep(0.05) + await manager.stop() + + assert manager._client.runs.wait.call_count == 2 + assert len(outbound_received) == 2 + + _run(go()) + + def test_inbound_dedupe_key_fails_closed_without_workspace(self): + """Without a workspace identifier, skip dedupe instead of collapsing workspaces (willem #3).""" + from app.channels.manager import ChannelManager + + with_workspace = InboundMessage( + channel_name="slack", + chat_id="C1", + user_id="U1", + text="x", + metadata={"team_id": "T1", "message_id": "m1"}, + ) + assert ChannelManager._inbound_dedupe_key(with_workspace) == ("slack", "T1", "C1", "m1") + + without_workspace = InboundMessage( + channel_name="slack", + chat_id="C1", + user_id="U1", + text="x", + metadata={"message_id": "m1"}, + ) + assert ChannelManager._inbound_dedupe_key(without_workspace) is None + + def test_dispatch_loop_releases_dedupe_key_when_handling_fails(self, tmp_path): + """A transient handling failure must not black-hole a provider redelivery (ShenAC #1).""" + from app.channels.manager import ChannelManager + + async def go(): + bus = MessageBus() + store = ChannelStore(path=tmp_path / "store.json") + manager = ChannelManager(bus=bus, store=store) + client = _make_mock_langgraph_client() + attempts = {"n": 0} + + async def flaky_wait(*args, **kwargs): + attempts["n"] += 1 + if attempts["n"] == 1: + raise RuntimeError("transient gateway 503") + return {"messages": [{"type": "human", "content": "hi"}, {"type": "ai", "content": "recovered"}]} + + client.runs.wait = AsyncMock(side_effect=flaky_wait) + manager._client = client + + outbound_received: list[OutboundMessage] = [] + + async def capture_outbound(msg: OutboundMessage) -> None: + outbound_received.append(msg) + + bus.subscribe_outbound(capture_outbound) + await manager.start() + + inbound = InboundMessage( + channel_name="slack", + chat_id="C123", + user_id="U123", + text="hello", + metadata={"team_id": "T123", "message_id": "m-1"}, + ) + + # First delivery fails transiently; the dedupe key must be released. + await bus.publish_inbound(inbound) + await _wait_for(lambda: attempts["n"] == 1 and len(outbound_received) >= 1) + + # Provider redelivers the same message_id: it must be reprocessed, not dropped. + await bus.publish_inbound(inbound) + await _wait_for(lambda: attempts["n"] == 2) + await asyncio.sleep(0.05) + await manager.stop() + + assert attempts["n"] == 2 + + _run(go()) + def test_handle_chat_outbound_preserves_inbound_metadata(self): """DingTalk (and similar) need inbound metadata on outbound sends (e.g. sender_staff_id).""" from app.channels.manager import ChannelManager @@ -3752,7 +3872,7 @@ class TestWeComChannel: assert inbound.thread_ts == "msg-1" assert inbound.topic_id == "user-1" assert inbound.files == files - assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single"} + assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single", "message_id": "msg-1"} assert channel._ws_frames["msg-1"] is frame assert channel._ws_stream_ids["msg-1"] == "stream-1" diff --git a/backend/tests/test_slack_channel_connections.py b/backend/tests/test_slack_channel_connections.py index 5b718bb3c..bb8af9f60 100644 --- a/backend/tests/test_slack_channel_connections.py +++ b/backend/tests/test_slack_channel_connections.py @@ -98,24 +98,13 @@ def test_slack_send_uses_connection_bot_token_when_connection_id_is_present(): anyio.run(go) -def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch): +def test_slack_http_events_mode_is_rejected(monkeypatch, caplog): import anyio from app.channels.slack import SlackChannel - class FakeWebClient: - def __init__(self, token: str) -> None: - self.token = token - self.messages: list[dict] = [] - - def auth_test(self): - return {"user_id": "B-http"} - - def chat_postMessage(self, **kwargs): - self.messages.append(kwargs) - slack_sdk = ModuleType("slack_sdk") - slack_sdk.WebClient = FakeWebClient + slack_sdk.WebClient = object socket_mode = ModuleType("slack_sdk.socket_mode") socket_mode.SocketModeClient = object response = ModuleType("slack_sdk.socket_mode.response") @@ -129,26 +118,20 @@ def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch): bus=MessageBus(), config={ "bot_token": "xoxb-operator", + # Provide app_token too so the missing-token early return cannot + # fire before the HTTP-mode guard — otherwise the state assertions + # below would hold even if the guard were deleted. + "app_token": "xapp-token", "event_delivery": "http", "connection_repo": MagicMock(), }, ) - await channel.start() - assert channel._running is True - assert channel._web_client is not None - assert channel._web_client.token == "xoxb-operator" - assert channel._bot_user_id == "B-http" + with caplog.at_level("ERROR", logger="app.channels.slack"): + await channel.start() - await channel._post_connection_reply("C123", "Slack connected to DeerFlow.", "1710000000.000100") - - assert channel._web_client.messages == [ - { - "channel": "C123", - "text": "Slack connected to DeerFlow.", - "thread_ts": "1710000000.000100", - } - ] - await channel.stop() + assert channel._running is False + assert channel._web_client is None + assert "Slack HTTP Events mode is not supported" in caplog.text anyio.run(go)