fix(channels): add operational guardrails (#3584)

* fix(channels): add operational guardrails

* make format

* fix(channels): converge with #3582 to avoid merge-order conflicts

Drop this PR's DingTalk INFO-log redaction and hand it to #3582, which
already restructures that handler and will redact the same log there. This
PR no longer touches dingtalk.py, so the two PRs can merge to main in any
order without a conflict.

For WeChat, drop the contested thread_ts priority reorder (review #3) and
keep only what inbound dedupe needs: a server-stable message_id in the
inbound metadata (message_id/msg_id, no client_id per review #6). This is a
single added line inside the metadata dict, a region #3582 never touches, so
it auto-merges regardless of order.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): address three correctness review findings

1. Connect-code cap was racy (willem #1): _create_state ran delete-expired,
   count, and insert as three separate transactions, so concurrent connect
   POSTs from one owner could each see count < cap and all insert past it. Add
   ChannelConnectionRepository.create_oauth_state_within_cap which does
   delete+count+insert in a single transaction serialized per (owner,
   provider) — Postgres via pg_advisory_xact_lock, SQLite via the write lock
   the leading DELETE takes — and have the router use it.

2. Inbound dedupe key fell back to "" workspace (willem #3): two workspaces
   delivering without team/guild/aibotid would collapse to the same key and
   dedupe each other's messages. _inbound_dedupe_key now fails closed
   (returns None) when no workspace identifier is present.

3. Dedupe key was recorded on receipt and never released on failure
   (ShenAC #1): a transient error (DB blip, Gateway 503) left the key in place
   for the full TTL, so a provider redelivery of the same message_id — exactly
   the retry dedupe should absorb — was silently dropped. _handle_message now
   releases the key in the unexpected-exception branch so redelivery can
   recover, while keeping record-on-receipt so retries during handling are
   still deduped.
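
Findings 2 and 3 can be illustrated with a minimal sketch. The names `dedupe_key` and `InboundDeduper` are hypothetical stand-ins for illustration, not the manager's actual methods, and TTL handling is left out to keep the example short:

```python
from typing import Mapping, Optional, Tuple

Key = Tuple[str, str, str, str]

# Assumed key lists for this sketch; they mirror the ids named in this message.
STABLE_ID_KEYS = ("event_id", "message_id", "msg_id")
WORKSPACE_KEYS = ("workspace_id", "team_id", "guild_id", "aibotid")


def dedupe_key(channel: str, chat_id: str, metadata: Mapping[str, object]) -> Optional[Key]:
    """Return a dedupe key, or None when dedupe must be skipped (fail closed)."""
    message_id = next((str(metadata[k]) for k in STABLE_ID_KEYS if metadata.get(k)), None)
    workspace = next((str(metadata[k]) for k in WORKSPACE_KEYS if metadata.get(k)), None)
    if message_id is None or workspace is None:
        # No stable id or no workspace: never fall back to "" as a workspace.
        return None
    return (channel, workspace, chat_id, message_id)


class InboundDeduper:
    """Record-on-receipt dedupe with an explicit release for failed handling."""

    def __init__(self) -> None:
        self._seen = set()

    def is_duplicate(self, key: Optional[Key]) -> bool:
        if key is None:
            return False  # cannot dedupe safely, so always process
        if key in self._seen:
            return True
        self._seen.add(key)  # recorded on receipt, before handling starts
        return False

    def release(self, key: Optional[Key]) -> None:
        # Called when handling failed unexpectedly, so a redelivery is reprocessed.
        if key is not None:
            self._seen.discard(key)
```

Because the key is recorded before handling begins, a retry that arrives mid-handling is still absorbed; only an explicit `release` reopens the slot.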

Tests: repo cap enforcement incl. concurrent-issuance non-leak; dedupe
fail-closed; dedupe key release-on-failure redelivery recovery.
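
The cap enforcement from finding 1 can be sketched with the standard-library `sqlite3` module. This is an illustration under stated assumptions, not the repository code: the table `oauth_state`, the helpers `open_store` and `create_state_within_cap`, and the float timestamps are all hypothetical. It also swaps in `BEGIN IMMEDIATE` to take the SQLite write lock up front, whereas the change described above relies on the leading DELETE to acquire it:

```python
import sqlite3


def open_store() -> sqlite3.Connection:
    # isolation_level=None leaves transaction control to the explicit
    # BEGIN IMMEDIATE / COMMIT / ROLLBACK statements issued below.
    conn = sqlite3.connect(":memory:", isolation_level=None)
    conn.execute(
        "CREATE TABLE oauth_state ("
        "state TEXT PRIMARY KEY, owner TEXT NOT NULL, "
        "provider TEXT NOT NULL, expires_at REAL NOT NULL)"
    )
    return conn


def create_state_within_cap(conn, owner, provider, state, expires_at, max_pending, now) -> bool:
    """Delete expired rows, count pending rows, and insert in one transaction."""
    conn.execute("BEGIN IMMEDIATE")  # write lock held until COMMIT or ROLLBACK
    try:
        conn.execute(
            "DELETE FROM oauth_state WHERE owner = ? AND provider = ? AND expires_at < ?",
            (owner, provider, now),
        )
        (pending,) = conn.execute(
            "SELECT COUNT(*) FROM oauth_state WHERE owner = ? AND provider = ? AND expires_at >= ?",
            (owner, provider, now),
        ).fetchone()
        inserted = pending < max_pending
        if inserted:
            conn.execute(
                "INSERT INTO oauth_state (state, owner, provider, expires_at) VALUES (?, ?, ?, ?)",
                (state, owner, provider, expires_at),
            )
    except Exception:
        conn.execute("ROLLBACK")
        raise
    # Mirror the described behaviour: roll back when the cap is already reached.
    conn.execute("COMMIT" if inserted else "ROLLBACK")
    return inserted
```

Since the count and the insert share one write-locked transaction, a second caller cannot observe a stale count between them.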

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): address cleanup/efficiency and test review findings

Efficiency / cleanup:
- Dedupe key set drops client-generated ids (client_msg_id, client_id);
  keep only server-stable event_id/message_id/msg_id, which a provider's own
  redelivery preserves (ShenAC #6). Every provider already emits message_id.
- TTL/overflow pruning of _recent_inbound_events is now O(k): switch to an
  OrderedDict and popitem(last=False) from the front instead of scanning all
  4096 entries on every inbound (willem #4).
- Log "received inbound" only after the dedupe check so a provider retrying N
  times no longer logs N accepts; document that manager dedupe covers the
  agent run/final answer, not provider ack side-effects (willem #5, ShenAC #2).
- Slack drops the redundant `team_id or event.get("team")` fallback the caller
  already resolved (willem #6).
- create_oauth_state_within_cap prunes only this owner/provider's expired codes
  instead of a global DELETE on every connect POST; global cleanup still runs
  on consume_oauth_state (willem #7).
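
The O(k) pruning in the second bullet can be sketched as follows, where k is the number of entries actually evicted. `prune_front` is a hypothetical helper for illustration; it assumes, as the change does, that keys are never re-inserted, so insertion order is chronological:

```python
from collections import OrderedDict


def prune_front(recent: OrderedDict, now: float, ttl_seconds: float, max_entries: int) -> int:
    """Evict expired, then overflow, entries from the front; return how many were evicted."""
    evicted = 0
    # Expired entries cluster at the front, so stop at the first live one.
    while recent:
        oldest_key = next(iter(recent))
        if now - recent[oldest_key] <= ttl_seconds:
            break
        recent.popitem(last=False)
        evicted += 1
    # Enforce the size bound by dropping the oldest remaining entries.
    while len(recent) > max_entries:
        recent.popitem(last=False)
        evicted += 1
    return evicted
```

Each call touches only the entries it removes plus one live entry, rather than scanning the whole mapping on every inbound message.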

Tests:
- Dedupe test uses tmp_path instead of a leaked mkdtemp, uses distinct objects
  per publish, and adds a negative control: a different message_id is still
  processed, catching over-dedupe regressions (willem #8, ShenAC #4).
- Slack HTTP-mode rejection test supplies app_token so the missing-token early
  return can't mask the guard, giving the state assertions teeth (ShenAC #3).
- count_oauth_states test pins that the active row survives, not just the count
  (ShenAC #5).

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* make format

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
Nan Gao
2026-06-18 04:09:46 +02:00
committed by GitHub
parent 97dd9ecf73
commit 8c0830aea1
12 changed files with 468 additions and 51 deletions
+2 -2
@@ -837,14 +837,14 @@ class FeishuChannel(Channel):
 text = text.strip()
 logger.info(
-    "[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, parent_id=%s, thread_id=%s, sender=%s, text=%r",
+    "[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, parent_id=%s, thread_id=%s, sender=%s, text_len=%d",
     chat_id,
     msg_id,
     root_id,
     parent_id,
     feishu_thread_id,
     sender_id,
-    text[:100] if text else "",
+    len(text or ""),
 )
 if not (text or files_list):
+97 -4
@@ -7,6 +7,7 @@ import logging
 import mimetypes
 import re
 import time
+from collections import OrderedDict
 from collections.abc import Awaitable, Callable, Mapping
 from dataclasses import dataclass
 from pathlib import Path
@@ -62,6 +63,12 @@ MESSAGE_STREAM_EVENTS = ("messages-tuple", "messages")
 THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
 BOUND_IDENTITY_REQUIRED_MESSAGE = "Connect this channel from DeerFlow Settings, complete the in-channel connect step, then send your message again."
 BOUND_IDENTITY_UNAVAILABLE_MESSAGE = "Channel connection verification is temporarily unavailable. Please try again later or contact the DeerFlow operator."
+INBOUND_DEDUPE_TTL_SECONDS = 10 * 60
+INBOUND_DEDUPE_MAX_ENTRIES = 4096
+# Only server-stable provider message ids: client-generated ids (client_msg_id,
+# client_id) are not guaranteed identical across a provider's own redelivery, so
+# keying dedupe on them would miss exactly the retries we want to absorb.
+INBOUND_DEDUPE_METADATA_KEYS = ("event_id", "message_id", "msg_id")
 CHANNEL_CAPABILITIES = {
     "dingtalk": {"supports_streaming": False},
@@ -774,6 +781,10 @@ class ChannelManager:
         self._semaphore: asyncio.Semaphore | None = None
         self._running = False
         self._task: asyncio.Task | None = None
+        # Insertion order == chronological (keys are never re-inserted), so an
+        # OrderedDict lets us evict expired/overflow entries from the front in
+        # O(k) instead of scanning all entries on every inbound message.
+        self._recent_inbound_events: OrderedDict[tuple[str, str, str, str], float] = OrderedDict()
     @staticmethod
     def _channel_supports_streaming(channel_name: str) -> bool:
@@ -919,16 +930,94 @@ class ChannelManager:
             except asyncio.CancelledError:
                 break
+            # Dedupe before logging "received" so a provider retrying an event N
+            # times does not log N accepts; duplicates are logged once as ignored.
+            # Note: this manager-level dedupe only guards the agent run / final
+            # answer. Provider adapters may emit ack side-effects (a "Working on
+            # it…" reply, an "eyes" reaction) before publish_inbound, so those are
+            # intentionally not deduped here.
+            if self._is_duplicate_inbound(msg):
+                continue
             logger.info(
-                "[Manager] received inbound: channel=%s, chat_id=%s, type=%s, text=%r",
+                "[Manager] received inbound: channel=%s, chat_id=%s, type=%s, text_len=%d, files=%d",
                 msg.channel_name,
                 msg.chat_id,
                 msg.msg_type.value,
-                msg.text[:100] if msg.text else "",
+                len(msg.text or ""),
+                len(msg.files),
             )
             task = asyncio.create_task(self._handle_message(msg))
             task.add_done_callback(self._log_task_error)
+    @staticmethod
+    def _inbound_dedupe_key(msg: InboundMessage) -> tuple[str, str, str, str] | None:
+        metadata = msg.metadata or {}
+        message_id = None
+        for key in INBOUND_DEDUPE_METADATA_KEYS:
+            value = metadata.get(key)
+            if value:
+                message_id = str(value)
+                break
+        if message_id is None:
+            raw_message = metadata.get("raw_message")
+            if isinstance(raw_message, Mapping):
+                for key in INBOUND_DEDUPE_METADATA_KEYS:
+                    value = raw_message.get(key)
+                    if value:
+                        message_id = str(value)
+                        break
+        if message_id is None:
+            return None
+        # Fail closed: without a workspace/team/guild identifier we cannot tell two
+        # workspaces apart (e.g. Slack channel ids are not globally unique), so
+        # skip dedupe rather than risk collapsing distinct workspaces' messages.
+        workspace_id = msg.workspace_id or metadata.get("workspace_id") or metadata.get("team_id") or metadata.get("guild_id") or metadata.get("aibotid")
+        if not workspace_id:
+            return None
+        return (msg.channel_name, str(workspace_id), msg.chat_id, message_id)
+    def _is_duplicate_inbound(self, msg: InboundMessage) -> bool:
+        key = self._inbound_dedupe_key(msg)
+        if key is None:
+            return False
+        now = time.monotonic()
+        # Entries are in chronological insertion order, so expired ones cluster at
+        # the front: pop from the front until we hit a still-live entry.
+        while self._recent_inbound_events:
+            _, oldest_at = next(iter(self._recent_inbound_events.items()))
+            if now - oldest_at > INBOUND_DEDUPE_TTL_SECONDS:
+                self._recent_inbound_events.popitem(last=False)
+            else:
+                break
+        while len(self._recent_inbound_events) > INBOUND_DEDUPE_MAX_ENTRIES:
+            self._recent_inbound_events.popitem(last=False)
+        if key in self._recent_inbound_events:
+            logger.info(
+                "[Manager] duplicate inbound ignored: channel=%s, chat_id=%s, message_id=%s",
+                msg.channel_name,
+                msg.chat_id,
+                key[-1],
+            )
+            return True
+        self._recent_inbound_events[key] = now
+        return False
+    def _release_inbound_dedupe_key(self, msg: InboundMessage) -> None:
+        """Drop a recorded dedupe key so a provider redelivery can be reprocessed.
+        Called only on transient/unexpected handling failures: the key was
+        recorded on receipt so retries arriving *while* the message is being
+        handled are still deduped, but if handling fails we must not turn a
+        recoverable error into a TTL-long black hole for the same message_id.
+        """
+        key = self._inbound_dedupe_key(msg)
+        if key is not None:
+            self._recent_inbound_events.pop(key, None)
     @staticmethod
     def _log_task_error(task: asyncio.Task) -> None:
         """Surface unhandled exceptions from background tasks."""
@@ -979,6 +1068,10 @@ class ChannelManager:
                 msg.channel_name,
                 msg.chat_id,
             )
+            # Transient/unexpected failure: release the dedupe key so a provider
+            # redelivery of the same message can recover instead of being dropped
+            # for the dedupe TTL.
+            self._release_inbound_dedupe_key(msg)
             await self._send_error(msg, "An internal error occurred. Please try again.")
     # -- chat handling -----------------------------------------------------
@@ -1169,7 +1262,7 @@ class ChannelManager:
     )
     return
-logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
+logger.info("[Manager] invoking runs.wait(thread_id=%s, text_len=%d)", thread_id, len(msg.text or ""))
 run_kwargs: dict[str, Any] = {
     "input": {"messages": [human_message]},
     "config": run_config,
@@ -1236,7 +1329,7 @@ class ChannelManager:
     run_context: dict[str, Any],
     human_message: dict[str, Any],
 ) -> None:
-    logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100])
+    logger.info("[Manager] invoking runs.stream(thread_id=%s, text_len=%d)", thread_id, len(msg.text or ""))
     last_values: dict[str, Any] | list | None = None
     streamed_buffers: dict[str, str] = {}
+9 -10
@@ -90,15 +90,8 @@ class SlackChannel(Channel):
 bot_token = self.config.get("bot_token", "")
 app_token = self.config.get("app_token", "")
-if self._connection_repo is not None and self.config.get("event_delivery") == "http":
-    if not bot_token:
-        logger.error("Slack HTTP Events mode requires bot_token")
-        return
-    await self._initialize_operator_web_client(str(bot_token))
-    self._loop = asyncio.get_event_loop()
-    self._running = True
-    self.bus.subscribe_outbound(self._on_outbound)
-    logger.info("Slack channel started in HTTP Events mode")
+if self.config.get("event_delivery") == "http":
+    logger.error("Slack HTTP Events mode is not supported by this channel adapter; use Socket Mode with app_token")
     return
 if not bot_token or not app_token:
@@ -319,7 +312,7 @@ class SlackChannel(Channel):
 asyncio.run_coroutine_threadsafe(
     self._bind_connection_from_connect_code(
         event=event,
-        team_id=str(team_id or event.get("team") or ""),
+        team_id=str(team_id or ""),
         code=connect_code,
     ),
     self._loop,
@@ -343,6 +336,12 @@ class SlackChannel(Channel):
     text=text,
     msg_type=msg_type,
     thread_ts=thread_ts,
+    metadata={
+        # team_id is already resolved (payload team_id/team, else event team) by the caller.
+        "team_id": team_id,
+        "message_id": event.get("ts"),
+        "client_msg_id": event.get("client_msg_id"),
+    },
 )
 inbound.topic_id = thread_ts
+2
@@ -503,6 +503,7 @@ class TelegramChannel(Channel):
     text=text,
     msg_type=InboundMessageType.COMMAND,
     thread_ts=msg_id,
+    metadata={"message_id": msg_id},
 )
 inbound.topic_id = topic_id
 inbound = await self._attach_connection_identity(inbound)
@@ -546,6 +547,7 @@ class TelegramChannel(Channel):
     text=text,
     msg_type=InboundMessageType.CHAT,
     thread_ts=msg_id,
+    metadata={"message_id": msg_id},
 )
 inbound.topic_id = topic_id
 inbound = await self._attach_connection_identity(inbound)
+1
@@ -627,6 +627,7 @@ class WechatChannel(Channel):
 metadata={
     "context_token": context_token,
     "ilink_user_id": chat_id,
+    "message_id": str(raw_message.get("message_id") or raw_message.get("msg_id") or "").strip(),
     "ref_msg": self._extract_ref_message(raw_message),
     "raw_message": raw_message,
 },
+5 -1
@@ -313,7 +313,11 @@ class WeComChannel(Channel):
     msg_type=inbound_type,
     thread_ts=msg_id,
     files=files or [],
-    metadata={"aibotid": body.get("aibotid"), "chattype": body.get("chattype")},
+    metadata={
+        "aibotid": body.get("aibotid"),
+        "chattype": body.get("chattype"),
+        "message_id": msg_id,
+    },
 )
 inbound.topic_id = user_id  # keep the same thread
@@ -24,6 +24,7 @@ router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
 logger = logging.getLogger(__name__)
 _STATE_TTL_SECONDS = 600
+_MAX_PENDING_CONNECT_CODES_PER_PROVIDER = 5
 _MASKED_CREDENTIAL_VALUE = "********"
@@ -332,13 +333,23 @@ async def _create_state(
     owner_user_id: str,
     provider: str,
 ) -> str:
+    now = datetime.now(UTC)
     state = _new_binding_code()
-    await repo.create_oauth_state(
+    # Atomic delete-expired + count + insert so concurrent connect POSTs from one
+    # owner cannot each see count < cap and all insert past the cap.
+    inserted = await repo.create_oauth_state_within_cap(
         owner_user_id=owner_user_id,
         provider=provider,
         state=state,
-        expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
+        expires_at=now + timedelta(seconds=_STATE_TTL_SECONDS),
+        max_pending=_MAX_PENDING_CONNECT_CODES_PER_PROVIDER,
+        now=now,
     )
+    if not inserted:
+        raise HTTPException(
+            status_code=429,
+            detail="Too many pending channel connection codes. Wait for existing codes to expire or use one of them.",
+        )
     return state
@@ -11,7 +11,7 @@ from datetime import UTC, datetime
 from typing import Any
 from cryptography.fernet import Fernet, InvalidToken
-from sqlalchemy import delete, func, select, update
+from sqlalchemy import delete, func, select, text, update
 from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -279,16 +279,128 @@ class ChannelConnectionRepository:
session.add(row) session.add(row)
await session.commit() await session.commit()
async def count_oauth_states(self, *, owner_user_id: str, provider: str) -> int: async def create_oauth_state_within_cap(
self,
*,
owner_user_id: str,
provider: str,
state: str,
expires_at: datetime,
max_pending: int,
now: datetime | None = None,
code_verifier: str | None = None,
nonce_hash: str | None = None,
redirect_after: str | None = None,
requested_scopes: list[str] | None = None,
metadata: dict[str, Any] | None = None,
) -> bool:
"""Atomically enforce the per-(owner, provider) pending cap, then insert.
delete-expired + count + insert run in a single transaction serialized
per (owner, provider), so concurrent connect requests cannot each
observe ``count < max_pending`` and all insert (which would leak past
the cap). PostgreSQL takes a transaction-scoped advisory lock; SQLite
serializes writers through the write lock the leading DELETE acquires.
Returns ``True`` when the row was inserted, ``False`` when the cap is
already reached.
"""
current_time = now or datetime.now(UTC)
async with self.session_factory() as session: async with self.session_factory() as session:
result = await session.execute( await self._serialize_oauth_owner_scope(session, owner_user_id, provider)
# Prune only this owner/provider's expired codes (the ones that affect
# this cap), not every user's — avoids a global DELETE on each connect
# POST. Issuing this write first also takes the SQLite database write
# lock so the count below cannot race a concurrent inserter between
# count and commit. Stale codes for other owners are pruned globally
# by consume_oauth_state / delete_expired_oauth_states.
await session.execute(
delete(ChannelOAuthStateRow).where(
ChannelOAuthStateRow.owner_user_id == owner_user_id,
ChannelOAuthStateRow.provider == provider,
ChannelOAuthStateRow.expires_at < current_time,
)
)
pending = await session.execute(
select(func.count()) select(func.count())
.select_from(ChannelOAuthStateRow) .select_from(ChannelOAuthStateRow)
.where( .where(
ChannelOAuthStateRow.owner_user_id == owner_user_id, ChannelOAuthStateRow.owner_user_id == owner_user_id,
ChannelOAuthStateRow.provider == provider, ChannelOAuthStateRow.provider == provider,
ChannelOAuthStateRow.consumed_at.is_(None),
ChannelOAuthStateRow.expires_at >= current_time,
) )
) )
if int(pending.scalar_one()) >= max_pending:
await session.rollback()
return False
session.add(
ChannelOAuthStateRow(
state_hash=self.hash_state(state),
owner_user_id=owner_user_id,
provider=provider,
code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
nonce_hash=nonce_hash,
redirect_after=redirect_after,
requested_scopes_json=list(requested_scopes or []),
metadata_json=dict(metadata or {}),
expires_at=expires_at,
)
)
await session.commit()
return True
async def _serialize_oauth_owner_scope(self, session: AsyncSession, owner_user_id: str, provider: str) -> None:
"""Serialize concurrent pending-cap transactions for one (owner, provider).
On PostgreSQL this takes a transaction-scoped advisory lock so concurrent
issuers run their count+insert one at a time. On SQLite the leading
DELETE in the caller's transaction already acquires the database write
lock, which serializes writers, so no extra lock is required.
"""
try:
dialect = session.bind.dialect.name if session.bind is not None else ""
except Exception:
dialect = ""
if dialect == "postgresql":
await session.execute(text("SELECT pg_advisory_xact_lock(:lock_key)"), {"lock_key": self._oauth_scope_lock_key(owner_user_id, provider)})
@staticmethod
def _oauth_scope_lock_key(owner_user_id: str, provider: str) -> int:
digest = hashlib.sha256(f"{owner_user_id}\x00{provider}".encode()).digest()
# 63-bit non-negative key for pg_advisory_xact_lock(bigint).
return int.from_bytes(digest[:8], "big") & 0x7FFFFFFFFFFFFFFF
async def delete_expired_oauth_states(self, *, now: datetime | None = None) -> int:
current_time = now or datetime.now(UTC)
async with self.session_factory() as session:
result = await session.execute(delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time))
await session.commit()
return int(result.rowcount or 0)
async def count_oauth_states(
self,
*,
owner_user_id: str,
provider: str,
active_only: bool = False,
now: datetime | None = None,
) -> int:
current_time = now or datetime.now(UTC)
conditions = [
ChannelOAuthStateRow.owner_user_id == owner_user_id,
ChannelOAuthStateRow.provider == provider,
]
if active_only:
conditions.extend(
[
ChannelOAuthStateRow.consumed_at.is_(None),
ChannelOAuthStateRow.expires_at >= current_time,
]
)
async with self.session_factory() as session:
result = await session.execute(select(func.count()).select_from(ChannelOAuthStateRow).where(*conditions))
return int(result.scalar_one()) return int(result.scalar_one())
async def consume_oauth_state( async def consume_oauth_state(
@@ -246,6 +246,77 @@ class TestChannelConnectionRepository:
states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all() states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
assert [state.state_hash for state in states] == [repo.hash_state("active-state")] assert [state.state_hash for state in states] == [repo.hash_state("active-state")]
@pytest.mark.anyio
async def test_count_oauth_states_active_only_and_delete_expired(self, repo):
now = datetime.now(UTC)
await repo.create_oauth_state(
owner_user_id="alice",
provider="slack",
state="expired-state",
expires_at=now - timedelta(minutes=1),
)
await repo.create_oauth_state(
owner_user_id="alice",
provider="slack",
state="active-state",
expires_at=now + timedelta(minutes=5),
)
assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 1
assert await repo.delete_expired_oauth_states(now=now) == 1
assert await repo.count_oauth_states(owner_user_id="alice", provider="slack") == 1
# Pin that the surviving row is the active one (an inverted expiry
# predicate would delete the active row, still return 1, and pass above).
async with repo.session_factory() as session:
survivors = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
assert [row.state_hash for row in survivors] == [repo.hash_state("active-state")]
@pytest.mark.anyio
async def test_create_oauth_state_within_cap_enforces_pending_cap(self, repo):
now = datetime.now(UTC)
expires = now + timedelta(minutes=5)
for i in range(3):
inserted = await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state=f"code-{i}", expires_at=expires, max_pending=3, now=now)
assert inserted is True
# Cap reached: the next issuance is rejected and nothing is inserted.
assert await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state="code-over", expires_at=expires, max_pending=3, now=now) is False
assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 3
# Expired rows are pruned and free up capacity; a different owner is unaffected.
assert await repo.create_oauth_state_within_cap(owner_user_id="bob", provider="slack", state="bob-1", expires_at=expires, max_pending=3, now=now) is True
@pytest.mark.anyio
async def test_create_oauth_state_within_cap_ignores_expired_rows(self, repo):
now = datetime.now(UTC)
# Three already-expired rows must not count against the cap.
for i in range(3):
await repo.create_oauth_state(owner_user_id="alice", provider="slack", state=f"old-{i}", expires_at=now - timedelta(minutes=1))
inserted = await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state="fresh", expires_at=now + timedelta(minutes=5), max_pending=3, now=now)
assert inserted is True
assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 1
@pytest.mark.anyio
async def test_create_oauth_state_within_cap_does_not_leak_under_concurrency(self, repo):
"""Concurrent issuance for one owner cannot push past the cap (willem #1)."""
import anyio
now = datetime.now(UTC)
expires = now + timedelta(minutes=5)
results: list[bool] = []
async def issue(state: str) -> None:
results.append(await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state=state, expires_at=expires, max_pending=3, now=now))
async with anyio.create_task_group() as tg:
for i in range(8):
tg.start_soon(issue, f"code-{i}")
assert sum(1 for ok in results if ok) == 3
assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 3
@pytest.mark.anyio @pytest.mark.anyio
async def test_consume_oauth_state_is_one_time_even_under_concurrent_consumers(self, repo): async def test_consume_oauth_state_is_one_time_even_under_concurrent_consumers(self, repo):
import anyio import anyio
@@ -504,6 +504,27 @@ def test_connect_slack_returns_binding_command_and_persists_state(tmp_path):
anyio.run(repo.close) anyio.run(repo.close)
def test_connect_binding_code_caps_pending_states_per_provider(tmp_path):
import anyio
repo = anyio.run(_make_repo, tmp_path)
app = _make_app(_enabled_connections_config(), repo, _channels_config())
with TestClient(app) as client:
responses = [client.post("/api/channels/slack/connect") for _ in range(6)]
assert [response.status_code for response in responses[:5]] == [200, 200, 200, 200, 200]
assert responses[5].status_code == 429
assert "Too many pending channel connection codes" in responses[5].json()["detail"]
async def count_states():
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="slack")
assert anyio.run(count_states) == 5
anyio.run(repo.close)
def test_connect_discord_returns_binding_command_and_persists_state(tmp_path): def test_connect_discord_returns_binding_command_and_persists_state(tmp_path):
import anyio import anyio
+121 -1
@@ -800,6 +800,126 @@ class TestChannelManager:
_run(go()) _run(go())
def test_dispatch_loop_dedupes_stable_provider_message_id(self, tmp_path):
from app.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=tmp_path / "store.json")
manager = ChannelManager(bus=bus, store=store)
manager._client = _make_mock_langgraph_client()
outbound_received: list[OutboundMessage] = []
async def capture_outbound(msg: OutboundMessage) -> None:
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
await manager.start()
def _slack_inbound(message_id: str) -> InboundMessage:
# Distinct objects per publish, like a real provider redelivery.
return InboundMessage(
channel_name="slack",
chat_id="C123",
user_id="U123",
text="sensitive prompt",
topic_id="1710000000.000100",
metadata={"team_id": "T123", "message_id": message_id},
)
# Same stable message_id delivered twice -> processed once.
await bus.publish_inbound(_slack_inbound("1710000000.000200"))
await bus.publish_inbound(_slack_inbound("1710000000.000200"))
await _wait_for(lambda: manager._client.runs.wait.call_count == 1 and len(outbound_received) == 1)
await asyncio.sleep(0.05)
assert manager._client.threads.create.call_count == 1
assert manager._client.runs.wait.call_count == 1
assert len(outbound_received) == 1
# Negative control: a *different* message_id must still be processed,
# so an over-dedupe regression (dropping distinct messages) is caught.
await bus.publish_inbound(_slack_inbound("1710000000.000999"))
await _wait_for(lambda: manager._client.runs.wait.call_count == 2 and len(outbound_received) == 2)
await asyncio.sleep(0.05)
await manager.stop()
assert manager._client.runs.wait.call_count == 2
assert len(outbound_received) == 2
_run(go())
def test_inbound_dedupe_key_fails_closed_without_workspace(self):
"""Without a workspace identifier, skip dedupe instead of collapsing workspaces (willem #3)."""
from app.channels.manager import ChannelManager
with_workspace = InboundMessage(
channel_name="slack",
chat_id="C1",
user_id="U1",
text="x",
metadata={"team_id": "T1", "message_id": "m1"},
)
assert ChannelManager._inbound_dedupe_key(with_workspace) == ("slack", "T1", "C1", "m1")
without_workspace = InboundMessage(
channel_name="slack",
chat_id="C1",
user_id="U1",
text="x",
metadata={"message_id": "m1"},
)
assert ChannelManager._inbound_dedupe_key(without_workspace) is None
def test_dispatch_loop_releases_dedupe_key_when_handling_fails(self, tmp_path):
"""A transient handling failure must not black-hole a provider redelivery (ShenAC #1)."""
from app.channels.manager import ChannelManager
async def go():
bus = MessageBus()
store = ChannelStore(path=tmp_path / "store.json")
manager = ChannelManager(bus=bus, store=store)
client = _make_mock_langgraph_client()
attempts = {"n": 0}
async def flaky_wait(*args, **kwargs):
attempts["n"] += 1
if attempts["n"] == 1:
raise RuntimeError("transient gateway 503")
return {"messages": [{"type": "human", "content": "hi"}, {"type": "ai", "content": "recovered"}]}
client.runs.wait = AsyncMock(side_effect=flaky_wait)
manager._client = client
outbound_received: list[OutboundMessage] = []
async def capture_outbound(msg: OutboundMessage) -> None:
outbound_received.append(msg)
bus.subscribe_outbound(capture_outbound)
await manager.start()
inbound = InboundMessage(
channel_name="slack",
chat_id="C123",
user_id="U123",
text="hello",
metadata={"team_id": "T123", "message_id": "m-1"},
)
# First delivery fails transiently; the dedupe key must be released.
await bus.publish_inbound(inbound)
await _wait_for(lambda: attempts["n"] == 1 and len(outbound_received) >= 1)
# Provider redelivers the same message_id: it must be reprocessed, not dropped.
await bus.publish_inbound(inbound)
await _wait_for(lambda: attempts["n"] == 2)
await asyncio.sleep(0.05)
await manager.stop()
assert attempts["n"] == 2
_run(go())
def test_handle_chat_outbound_preserves_inbound_metadata(self): def test_handle_chat_outbound_preserves_inbound_metadata(self):
"""DingTalk (and similar) need inbound metadata on outbound sends (e.g. sender_staff_id).""" """DingTalk (and similar) need inbound metadata on outbound sends (e.g. sender_staff_id)."""
from app.channels.manager import ChannelManager from app.channels.manager import ChannelManager
@@ -3752,7 +3872,7 @@ class TestWeComChannel:
 assert inbound.thread_ts == "msg-1"
 assert inbound.topic_id == "user-1"
 assert inbound.files == files
-assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single"}
+assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single", "message_id": "msg-1"}
 assert channel._ws_frames["msg-1"] is frame
 assert channel._ws_stream_ids["msg-1"] == "stream-1"
+11 -28
@@ -98,24 +98,13 @@ def test_slack_send_uses_connection_bot_token_when_connection_id_is_present():
     anyio.run(go)
-def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch):
+def test_slack_http_events_mode_is_rejected(monkeypatch, caplog):
     import anyio
     from app.channels.slack import SlackChannel
-    class FakeWebClient:
-        def __init__(self, token: str) -> None:
-            self.token = token
-            self.messages: list[dict] = []
-        def auth_test(self):
-            return {"user_id": "B-http"}
-        def chat_postMessage(self, **kwargs):
-            self.messages.append(kwargs)
     slack_sdk = ModuleType("slack_sdk")
-    slack_sdk.WebClient = FakeWebClient
+    slack_sdk.WebClient = object
     socket_mode = ModuleType("slack_sdk.socket_mode")
     socket_mode.SocketModeClient = object
     response = ModuleType("slack_sdk.socket_mode.response")
@@ -129,26 +118,20 @@ def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch):
             bus=MessageBus(),
             config={
                 "bot_token": "xoxb-operator",
+                # Provide app_token too so the missing-token early return cannot
+                # fire before the HTTP-mode guard — otherwise the state assertions
+                # below would hold even if the guard were deleted.
+                "app_token": "xapp-token",
                 "event_delivery": "http",
                 "connection_repo": MagicMock(),
             },
         )
-        await channel.start()
-        assert channel._running is True
-        assert channel._web_client is not None
-        assert channel._web_client.token == "xoxb-operator"
-        assert channel._bot_user_id == "B-http"
-        await channel._post_connection_reply("C123", "Slack connected to DeerFlow.", "1710000000.000100")
-        assert channel._web_client.messages == [
-            {
-                "channel": "C123",
-                "text": "Slack connected to DeerFlow.",
-                "thread_ts": "1710000000.000100",
-            }
-        ]
-        await channel.stop()
+        with caplog.at_level("ERROR", logger="app.channels.slack"):
+            await channel.start()
+        assert channel._running is False
+        assert channel._web_client is None
+        assert "Slack HTTP Events mode is not supported" in caplog.text
     anyio.run(go)