mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-18 05:25:57 +00:00
fix(channels): add operational guardrails (#3584)
* fix(channels): add operational guardrails

* make format

* fix(channels): converge with #3582 to avoid merge-order conflicts

  Drop this PR's DingTalk INFO-log redaction and hand it to #3582, which already restructures that handler and will redact the same log there. This PR no longer touches dingtalk.py, so the two PRs can merge to main in any order without a conflict.

  For WeChat, drop the contested thread_ts priority reorder (review #3) and keep only what inbound dedupe needs: a server-stable message_id in the inbound metadata (message_id/msg_id, no client_id per review #6). This is a single added line inside the metadata dict, a region #3582 never touches, so it auto-merges regardless of order.

  Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): address three correctness review findings

  1. Connect-code cap was racy (willem #1): _create_state ran delete-expired, count, and insert as three separate transactions, so concurrent connect POSTs from one owner could each see count < cap and all insert past it. Add ChannelConnectionRepository.create_oauth_state_within_cap which does delete+count+insert in a single transaction serialized per (owner, provider) — Postgres via pg_advisory_xact_lock, SQLite via the write lock the leading DELETE takes — and have the router use it.
  2. Inbound dedupe key fell back to "" workspace (willem #3): two workspaces delivering without team/guild/aibotid would collapse to the same key and dedupe each other's messages. _inbound_dedupe_key now fails closed (returns None) when no workspace identifier is present.
  3. Dedupe key was recorded on receipt and never released on failure (ShenAC #1): a transient error (DB blip, Gateway 503) left the key in place for the full TTL, so a provider redelivery of the same message_id — exactly the retry dedupe should absorb — was silently dropped. _handle_message now releases the key in the unexpected-exception branch so redelivery can recover, while keeping record-on-receipt so retries during handling are still deduped.

  Tests: repo cap enforcement incl. concurrent-issuance non-leak; dedupe fail-closed; dedupe key release-on-failure redelivery recovery.

  Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* fix(channels): address cleanup/efficiency and test review findings

  Efficiency / cleanup:
  - Dedupe key set drops client-generated ids (client_msg_id, client_id); keep only server-stable event_id/message_id/msg_id, which a provider's own redelivery preserves (ShenAC #6). Every provider already emits message_id.
  - TTL/overflow pruning of _recent_inbound_events is now O(k): switch to an OrderedDict and popitem(last=False) from the front instead of scanning all 4096 entries on every inbound (willem #4).
  - Log "received inbound" only after the dedupe check so a provider retrying N times no longer logs N accepts; document that manager dedupe covers the agent run/final answer, not provider ack side-effects (willem #5, ShenAC #2).
  - Slack drops the redundant `team_id or event.get("team")` fallback the caller already resolved (willem #6).
  - create_oauth_state_within_cap prunes only this owner/provider's expired codes instead of a global DELETE on every connect POST; global cleanup still runs on consume_oauth_state (willem #7).

  Tests:
  - Dedupe test uses tmp_path instead of a leaked mkdtemp, uses distinct objects per publish, and adds a negative control: a different message_id is still processed, catching over-dedupe regressions (willem #8, ShenAC #4).
  - Slack HTTP-mode rejection test supplies app_token so the missing-token early return can't mask the guard, giving the state assertions teeth (ShenAC #3).
  - count_oauth_states test pins that the active row survives, not just the count (ShenAC #5).

  Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>

* make format

---------

Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
This commit is contained in:
@@ -837,14 +837,14 @@ class FeishuChannel(Channel):
|
||||
text = text.strip()
|
||||
|
||||
logger.info(
|
||||
"[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, parent_id=%s, thread_id=%s, sender=%s, text=%r",
|
||||
"[Feishu] parsed message: chat_id=%s, msg_id=%s, root_id=%s, parent_id=%s, thread_id=%s, sender=%s, text_len=%d",
|
||||
chat_id,
|
||||
msg_id,
|
||||
root_id,
|
||||
parent_id,
|
||||
feishu_thread_id,
|
||||
sender_id,
|
||||
text[:100] if text else "",
|
||||
len(text or ""),
|
||||
)
|
||||
|
||||
if not (text or files_list):
|
||||
|
||||
@@ -7,6 +7,7 @@ import logging
|
||||
import mimetypes
|
||||
import re
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Awaitable, Callable, Mapping
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
@@ -62,6 +63,12 @@ MESSAGE_STREAM_EVENTS = ("messages-tuple", "messages")
|
||||
THREAD_BUSY_MESSAGE = "This conversation is already processing another request. Please wait for it to finish and try again."
|
||||
BOUND_IDENTITY_REQUIRED_MESSAGE = "Connect this channel from DeerFlow Settings, complete the in-channel connect step, then send your message again."
|
||||
BOUND_IDENTITY_UNAVAILABLE_MESSAGE = "Channel connection verification is temporarily unavailable. Please try again later or contact the DeerFlow operator."
|
||||
INBOUND_DEDUPE_TTL_SECONDS = 10 * 60
|
||||
INBOUND_DEDUPE_MAX_ENTRIES = 4096
|
||||
# Only server-stable provider message ids: client-generated ids (client_msg_id,
|
||||
# client_id) are not guaranteed identical across a provider's own redelivery, so
|
||||
# keying dedupe on them would miss exactly the retries we want to absorb.
|
||||
INBOUND_DEDUPE_METADATA_KEYS = ("event_id", "message_id", "msg_id")
|
||||
|
||||
CHANNEL_CAPABILITIES = {
|
||||
"dingtalk": {"supports_streaming": False},
|
||||
@@ -774,6 +781,10 @@ class ChannelManager:
|
||||
self._semaphore: asyncio.Semaphore | None = None
|
||||
self._running = False
|
||||
self._task: asyncio.Task | None = None
|
||||
# Insertion order == chronological (keys are never re-inserted), so an
|
||||
# OrderedDict lets us evict expired/overflow entries from the front in
|
||||
# O(k) instead of scanning all entries on every inbound message.
|
||||
self._recent_inbound_events: OrderedDict[tuple[str, str, str, str], float] = OrderedDict()
|
||||
|
||||
@staticmethod
|
||||
def _channel_supports_streaming(channel_name: str) -> bool:
|
||||
@@ -919,16 +930,94 @@ class ChannelManager:
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
|
||||
# Dedupe before logging "received" so a provider retrying an event N
|
||||
# times does not log N accepts; duplicates are logged once as ignored.
|
||||
# Note: this manager-level dedupe only guards the agent run / final
|
||||
# answer. Provider adapters may emit ack side-effects (a "Working on
|
||||
# it…" reply, an "eyes" reaction) before publish_inbound, so those are
|
||||
# intentionally not deduped here.
|
||||
if self._is_duplicate_inbound(msg):
|
||||
continue
|
||||
logger.info(
|
||||
"[Manager] received inbound: channel=%s, chat_id=%s, type=%s, text=%r",
|
||||
"[Manager] received inbound: channel=%s, chat_id=%s, type=%s, text_len=%d, files=%d",
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
msg.msg_type.value,
|
||||
msg.text[:100] if msg.text else "",
|
||||
len(msg.text or ""),
|
||||
len(msg.files),
|
||||
)
|
||||
task = asyncio.create_task(self._handle_message(msg))
|
||||
task.add_done_callback(self._log_task_error)
|
||||
|
||||
@staticmethod
def _inbound_dedupe_key(msg: InboundMessage) -> tuple[str, str, str, str] | None:
    """Build the dedupe key for an inbound message, or ``None`` to skip dedupe.

    The key is ``(channel_name, workspace_id, chat_id, message_id)``.

    Returns:
        The key tuple, or ``None`` when either no message id (one of
        ``INBOUND_DEDUPE_METADATA_KEYS``) or no workspace identifier can be
        found. ``None`` means "do not dedupe this message".
    """
    metadata = msg.metadata or {}

    # Top-level metadata takes precedence; the nested "raw_message" mapping is
    # only consulted when the top level carries none of the id keys.
    sources = [metadata]
    raw_message = metadata.get("raw_message")
    if isinstance(raw_message, Mapping):
        sources.append(raw_message)

    # First truthy id wins, in source order and then key order. Empty strings
    # and None are treated as "absent".
    message_id = next(
        (str(value) for source in sources for name in INBOUND_DEDUPE_METADATA_KEYS if (value := source.get(name))),
        None,
    )
    if message_id is None:
        return None

    # Fail closed: without a workspace/team/guild identifier we cannot tell two
    # workspaces apart (e.g. Slack channel ids are not globally unique), so
    # skip dedupe rather than risk collapsing distinct workspaces' messages.
    workspace_id = (
        msg.workspace_id
        or metadata.get("workspace_id")
        or metadata.get("team_id")
        or metadata.get("guild_id")
        or metadata.get("aibotid")
    )
    if not workspace_id:
        return None
    return (msg.channel_name, str(workspace_id), msg.chat_id, message_id)
|
||||
|
||||
def _is_duplicate_inbound(self, msg: InboundMessage) -> bool:
    """Record ``msg`` as seen and report whether it was already seen.

    Messages for which ``_inbound_dedupe_key`` returns ``None`` are never
    treated as duplicates and are not recorded.

    Returns:
        ``True`` when the same key is still present in
        ``_recent_inbound_events`` (within the TTL); ``False`` otherwise, in
        which case the key is recorded with the current monotonic time.
    """
    key = self._inbound_dedupe_key(msg)
    if key is None:
        return False

    now = time.monotonic()
    events = self._recent_inbound_events

    # Entries are in chronological insertion order, so expired ones cluster at
    # the front: pop from the front until we hit a still-live entry.
    while events:
        oldest_at = next(iter(events.values()))
        if now - oldest_at <= INBOUND_DEDUPE_TTL_SECONDS:
            break
        events.popitem(last=False)

    if key in events:
        logger.info(
            "[Manager] duplicate inbound ignored: channel=%s, chat_id=%s, message_id=%s",
            msg.channel_name,
            msg.chat_id,
            key[-1],
        )
        return True

    events[key] = now
    # Enforce the size cap *after* inserting so the map never holds more than
    # INBOUND_DEDUPE_MAX_ENTRIES keys (evicting before the insert allowed
    # MAX + 1). Oldest entries are dropped first.
    while len(events) > INBOUND_DEDUPE_MAX_ENTRIES:
        events.popitem(last=False)
    return False
|
||||
|
||||
def _release_inbound_dedupe_key(self, msg: InboundMessage) -> None:
|
||||
"""Drop a recorded dedupe key so a provider redelivery can be reprocessed.
|
||||
|
||||
Called only on transient/unexpected handling failures: the key was
|
||||
recorded on receipt so retries arriving *while* the message is being
|
||||
handled are still deduped, but if handling fails we must not turn a
|
||||
recoverable error into a TTL-long black hole for the same message_id.
|
||||
"""
|
||||
key = self._inbound_dedupe_key(msg)
|
||||
if key is not None:
|
||||
self._recent_inbound_events.pop(key, None)
|
||||
|
||||
@staticmethod
|
||||
def _log_task_error(task: asyncio.Task) -> None:
|
||||
"""Surface unhandled exceptions from background tasks."""
|
||||
@@ -979,6 +1068,10 @@ class ChannelManager:
|
||||
msg.channel_name,
|
||||
msg.chat_id,
|
||||
)
|
||||
# Transient/unexpected failure: release the dedupe key so a provider
|
||||
# redelivery of the same message can recover instead of being dropped
|
||||
# for the dedupe TTL.
|
||||
self._release_inbound_dedupe_key(msg)
|
||||
await self._send_error(msg, "An internal error occurred. Please try again.")
|
||||
|
||||
# -- chat handling -----------------------------------------------------
|
||||
@@ -1169,7 +1262,7 @@ class ChannelManager:
|
||||
)
|
||||
return
|
||||
|
||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||
logger.info("[Manager] invoking runs.wait(thread_id=%s, text_len=%d)", thread_id, len(msg.text or ""))
|
||||
run_kwargs: dict[str, Any] = {
|
||||
"input": {"messages": [human_message]},
|
||||
"config": run_config,
|
||||
@@ -1236,7 +1329,7 @@ class ChannelManager:
|
||||
run_context: dict[str, Any],
|
||||
human_message: dict[str, Any],
|
||||
) -> None:
|
||||
logger.info("[Manager] invoking runs.stream(thread_id=%s, text=%r)", thread_id, msg.text[:100])
|
||||
logger.info("[Manager] invoking runs.stream(thread_id=%s, text_len=%d)", thread_id, len(msg.text or ""))
|
||||
|
||||
last_values: dict[str, Any] | list | None = None
|
||||
streamed_buffers: dict[str, str] = {}
|
||||
|
||||
@@ -90,15 +90,8 @@ class SlackChannel(Channel):
|
||||
bot_token = self.config.get("bot_token", "")
|
||||
app_token = self.config.get("app_token", "")
|
||||
|
||||
if self._connection_repo is not None and self.config.get("event_delivery") == "http":
|
||||
if not bot_token:
|
||||
logger.error("Slack HTTP Events mode requires bot_token")
|
||||
return
|
||||
await self._initialize_operator_web_client(str(bot_token))
|
||||
self._loop = asyncio.get_event_loop()
|
||||
self._running = True
|
||||
self.bus.subscribe_outbound(self._on_outbound)
|
||||
logger.info("Slack channel started in HTTP Events mode")
|
||||
if self.config.get("event_delivery") == "http":
|
||||
logger.error("Slack HTTP Events mode is not supported by this channel adapter; use Socket Mode with app_token")
|
||||
return
|
||||
|
||||
if not bot_token or not app_token:
|
||||
@@ -319,7 +312,7 @@ class SlackChannel(Channel):
|
||||
asyncio.run_coroutine_threadsafe(
|
||||
self._bind_connection_from_connect_code(
|
||||
event=event,
|
||||
team_id=str(team_id or event.get("team") or ""),
|
||||
team_id=str(team_id or ""),
|
||||
code=connect_code,
|
||||
),
|
||||
self._loop,
|
||||
@@ -343,6 +336,12 @@ class SlackChannel(Channel):
|
||||
text=text,
|
||||
msg_type=msg_type,
|
||||
thread_ts=thread_ts,
|
||||
metadata={
|
||||
# team_id is already resolved (payload team_id/team, else event team) by the caller.
|
||||
"team_id": team_id,
|
||||
"message_id": event.get("ts"),
|
||||
"client_msg_id": event.get("client_msg_id"),
|
||||
},
|
||||
)
|
||||
inbound.topic_id = thread_ts
|
||||
|
||||
|
||||
@@ -503,6 +503,7 @@ class TelegramChannel(Channel):
|
||||
text=text,
|
||||
msg_type=InboundMessageType.COMMAND,
|
||||
thread_ts=msg_id,
|
||||
metadata={"message_id": msg_id},
|
||||
)
|
||||
inbound.topic_id = topic_id
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
@@ -546,6 +547,7 @@ class TelegramChannel(Channel):
|
||||
text=text,
|
||||
msg_type=InboundMessageType.CHAT,
|
||||
thread_ts=msg_id,
|
||||
metadata={"message_id": msg_id},
|
||||
)
|
||||
inbound.topic_id = topic_id
|
||||
inbound = await self._attach_connection_identity(inbound)
|
||||
|
||||
@@ -627,6 +627,7 @@ class WechatChannel(Channel):
|
||||
metadata={
|
||||
"context_token": context_token,
|
||||
"ilink_user_id": chat_id,
|
||||
"message_id": str(raw_message.get("message_id") or raw_message.get("msg_id") or "").strip(),
|
||||
"ref_msg": self._extract_ref_message(raw_message),
|
||||
"raw_message": raw_message,
|
||||
},
|
||||
|
||||
@@ -313,7 +313,11 @@ class WeComChannel(Channel):
|
||||
msg_type=inbound_type,
|
||||
thread_ts=msg_id,
|
||||
files=files or [],
|
||||
metadata={"aibotid": body.get("aibotid"), "chattype": body.get("chattype")},
|
||||
metadata={
|
||||
"aibotid": body.get("aibotid"),
|
||||
"chattype": body.get("chattype"),
|
||||
"message_id": msg_id,
|
||||
},
|
||||
)
|
||||
inbound.topic_id = user_id # keep the same thread
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ router = APIRouter(prefix="/api/channels", tags=["channel-connections"])
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_STATE_TTL_SECONDS = 600
|
||||
_MAX_PENDING_CONNECT_CODES_PER_PROVIDER = 5
|
||||
_MASKED_CREDENTIAL_VALUE = "********"
|
||||
|
||||
|
||||
@@ -332,12 +333,22 @@ async def _create_state(
|
||||
owner_user_id: str,
|
||||
provider: str,
|
||||
) -> str:
|
||||
now = datetime.now(UTC)
|
||||
state = _new_binding_code()
|
||||
await repo.create_oauth_state(
|
||||
# Atomic delete-expired + count + insert so concurrent connect POSTs from one
|
||||
# owner cannot each see count < cap and all insert past the cap.
|
||||
inserted = await repo.create_oauth_state_within_cap(
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(seconds=_STATE_TTL_SECONDS),
|
||||
expires_at=now + timedelta(seconds=_STATE_TTL_SECONDS),
|
||||
max_pending=_MAX_PENDING_CONNECT_CODES_PER_PROVIDER,
|
||||
now=now,
|
||||
)
|
||||
if not inserted:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail="Too many pending channel connection codes. Wait for existing codes to expire or use one of them.",
|
||||
)
|
||||
return state
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy import delete, func, select, text, update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
|
||||
@@ -279,16 +279,128 @@ class ChannelConnectionRepository:
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
|
||||
async def count_oauth_states(self, *, owner_user_id: str, provider: str) -> int:
|
||||
async def create_oauth_state_within_cap(
    self,
    *,
    owner_user_id: str,
    provider: str,
    state: str,
    expires_at: datetime,
    max_pending: int,
    now: datetime | None = None,
    code_verifier: str | None = None,
    nonce_hash: str | None = None,
    redirect_after: str | None = None,
    requested_scopes: list[str] | None = None,
    metadata: dict[str, Any] | None = None,
) -> bool:
    """Atomically enforce the per-(owner, provider) pending cap, then insert.

    delete-expired + count + insert run in a single transaction serialized
    per (owner, provider), so concurrent connect requests cannot each
    observe ``count < max_pending`` and all insert (which would leak past
    the cap). PostgreSQL takes a transaction-scoped advisory lock; SQLite
    serializes writers through the write lock the leading DELETE acquires.

    Args:
        owner_user_id: Owner the pending code belongs to.
        provider: Channel provider the code is issued for.
        state: Plain state/code value; only its hash is stored.
        expires_at: Expiry timestamp stored on the new row.
        max_pending: Maximum number of unconsumed, unexpired rows allowed
            for this (owner, provider) before insertion is refused.
        now: Reference time for the expiry comparisons; defaults to the
            current UTC time.
        code_verifier: Optional secret, stored encrypted.
        nonce_hash: Optional nonce hash stored as given.
        redirect_after: Optional redirect target stored as given.
        requested_scopes: Optional scopes; stored as a list (empty if omitted).
        metadata: Optional metadata; stored as a dict (empty if omitted).

    Returns:
        ``True`` when the row was inserted, ``False`` when the cap is
        already reached.
    """
    current_time = now or datetime.now(UTC)
    owner_scope = (
        ChannelOAuthStateRow.owner_user_id == owner_user_id,
        ChannelOAuthStateRow.provider == provider,
    )
    async with self.session_factory() as session:
        await self._serialize_oauth_owner_scope(session, owner_user_id, provider)
        # Prune only this owner/provider's expired codes (the ones that affect
        # this cap), not every user's — avoids a global DELETE on each connect
        # POST. Issuing this write first also takes the SQLite database write
        # lock so the count below cannot race a concurrent inserter between
        # count and commit. Stale codes for other owners are pruned globally
        # by consume_oauth_state / delete_expired_oauth_states.
        await session.execute(
            delete(ChannelOAuthStateRow).where(
                *owner_scope,
                ChannelOAuthStateRow.expires_at < current_time,
            )
        )
        pending = await session.execute(
            select(func.count())
            .select_from(ChannelOAuthStateRow)
            .where(
                *owner_scope,
                ChannelOAuthStateRow.consumed_at.is_(None),
                ChannelOAuthStateRow.expires_at >= current_time,
            )
        )
        if int(pending.scalar_one()) >= max_pending:
            # Cap reached: abandon the transaction (this also undoes the
            # prune above, which will simply be retried on the next call).
            await session.rollback()
            return False
        session.add(
            ChannelOAuthStateRow(
                state_hash=self.hash_state(state),
                owner_user_id=owner_user_id,
                provider=provider,
                code_verifier_encrypted=self._encrypt_optional_secret(code_verifier),
                nonce_hash=nonce_hash,
                redirect_after=redirect_after,
                requested_scopes_json=list(requested_scopes or []),
                metadata_json=dict(metadata or {}),
                expires_at=expires_at,
            )
        )
        await session.commit()
        return True
|
||||
|
||||
async def _serialize_oauth_owner_scope(self, session: AsyncSession, owner_user_id: str, provider: str) -> None:
    """Serialize concurrent pending-cap transactions for one (owner, provider).

    On PostgreSQL this takes a transaction-scoped advisory lock so concurrent
    issuers run their count+insert one at a time. On SQLite the leading
    DELETE in the caller's transaction already acquires the database write
    lock, which serializes writers, so no extra lock is required.

    For any other dialect, or when the dialect cannot be determined, this
    is a no-op.
    """
    try:
        dialect = session.bind.dialect.name if session.bind is not None else ""
    except Exception:
        # Best effort: an unreadable bind is treated as "unknown dialect".
        # NOTE(review): on PostgreSQL this would silently skip the lock and
        # make the cap racy again — consider logging here.
        dialect = ""
    if dialect == "postgresql":
        await session.execute(
            text("SELECT pg_advisory_xact_lock(:lock_key)"),
            {"lock_key": self._oauth_scope_lock_key(owner_user_id, provider)},
        )
|
||||
|
||||
@staticmethod
|
||||
def _oauth_scope_lock_key(owner_user_id: str, provider: str) -> int:
|
||||
digest = hashlib.sha256(f"{owner_user_id}\x00{provider}".encode()).digest()
|
||||
# 63-bit non-negative key for pg_advisory_xact_lock(bigint).
|
||||
return int.from_bytes(digest[:8], "big") & 0x7FFFFFFFFFFFFFFF
|
||||
|
||||
async def delete_expired_oauth_states(self, *, now: datetime | None = None) -> int:
    """Delete every OAuth state row whose ``expires_at`` is before ``now``.

    Applies to all owners and providers.

    Args:
        now: Reference time; defaults to the current UTC time.

    Returns:
        The number of rows deleted (``0`` when the driver reports no rowcount).
    """
    current_time = now or datetime.now(UTC)
    async with self.session_factory() as session:
        result = await session.execute(
            delete(ChannelOAuthStateRow).where(ChannelOAuthStateRow.expires_at < current_time)
        )
        await session.commit()
        return int(result.rowcount or 0)
|
||||
|
||||
async def count_oauth_states(
    self,
    *,
    owner_user_id: str,
    provider: str,
    active_only: bool = False,
    now: datetime | None = None,
) -> int:
    """Count OAuth state rows for one (owner, provider).

    Args:
        owner_user_id: Owner whose rows are counted.
        provider: Provider whose rows are counted.
        active_only: When true, count only rows that are unconsumed and
            whose ``expires_at`` is at or after ``now``.
        now: Reference time for ``active_only``; defaults to the current
            UTC time. Ignored when ``active_only`` is false.

    Returns:
        The number of matching rows.
    """
    current_time = now or datetime.now(UTC)
    conditions = [
        ChannelOAuthStateRow.owner_user_id == owner_user_id,
        ChannelOAuthStateRow.provider == provider,
    ]
    if active_only:
        conditions.extend(
            [
                ChannelOAuthStateRow.consumed_at.is_(None),
                ChannelOAuthStateRow.expires_at >= current_time,
            ]
        )

    async with self.session_factory() as session:
        result = await session.execute(select(func.count()).select_from(ChannelOAuthStateRow).where(*conditions))
        return int(result.scalar_one())
|
||||
|
||||
async def consume_oauth_state(
|
||||
|
||||
@@ -246,6 +246,77 @@ class TestChannelConnectionRepository:
|
||||
states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
|
||||
assert [state.state_hash for state in states] == [repo.hash_state("active-state")]
|
||||
|
||||
@pytest.mark.anyio
async def test_count_oauth_states_active_only_and_delete_expired(self, repo):
    """active_only counts only live rows, and delete-expired removes only expired ones."""
    now = datetime.now(UTC)
    await repo.create_oauth_state(
        owner_user_id="alice",
        provider="slack",
        state="expired-state",
        expires_at=now - timedelta(minutes=1),
    )
    await repo.create_oauth_state(
        owner_user_id="alice",
        provider="slack",
        state="active-state",
        expires_at=now + timedelta(minutes=5),
    )

    assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 1
    assert await repo.delete_expired_oauth_states(now=now) == 1
    assert await repo.count_oauth_states(owner_user_id="alice", provider="slack") == 1
    # Pin that the surviving row is the active one (an inverted expiry
    # predicate would delete the active row, still return 1, and pass above).
    async with repo.session_factory() as session:
        survivors = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
    assert [row.state_hash for row in survivors] == [repo.hash_state("active-state")]
|
||||
|
||||
@pytest.mark.anyio
async def test_create_oauth_state_within_cap_enforces_pending_cap(self, repo):
    """The cap rejects the (max_pending + 1)-th code and is scoped per owner."""
    now = datetime.now(UTC)
    expires = now + timedelta(minutes=5)

    for i in range(3):
        inserted = await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state=f"code-{i}", expires_at=expires, max_pending=3, now=now)
        assert inserted is True

    # Cap reached: the next issuance is rejected and nothing is inserted.
    assert await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state="code-over", expires_at=expires, max_pending=3, now=now) is False
    assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 3

    # The cap is per owner: alice's full quota does not block a different owner.
    # (Pruning of expired rows is covered by the *_ignores_expired_rows test.)
    assert await repo.create_oauth_state_within_cap(owner_user_id="bob", provider="slack", state="bob-1", expires_at=expires, max_pending=3, now=now) is True
    assert await repo.count_oauth_states(owner_user_id="bob", provider="slack", active_only=True, now=now) == 1
    assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 3
|
||||
|
||||
@pytest.mark.anyio
async def test_create_oauth_state_within_cap_ignores_expired_rows(self, repo):
    """Already-expired rows do not consume capacity."""
    now = datetime.now(UTC)
    # Three already-expired rows must not count against the cap.
    for i in range(3):
        await repo.create_oauth_state(owner_user_id="alice", provider="slack", state=f"old-{i}", expires_at=now - timedelta(minutes=1))

    inserted = await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state="fresh", expires_at=now + timedelta(minutes=5), max_pending=3, now=now)
    assert inserted is True
    assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 1
|
||||
|
||||
@pytest.mark.anyio
async def test_create_oauth_state_within_cap_does_not_leak_under_concurrency(self, repo):
    """Concurrent issuance for one owner cannot push past the cap (willem #1)."""
    import anyio

    now = datetime.now(UTC)
    expires = now + timedelta(minutes=5)
    results: list[bool] = []

    async def issue(state: str) -> None:
        results.append(await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state=state, expires_at=expires, max_pending=3, now=now))

    async with anyio.create_task_group() as tg:
        for i in range(8):
            tg.start_soon(issue, f"code-{i}")

    # Every attempt must have completed; exactly max_pending of them may win.
    assert len(results) == 8
    assert sum(1 for ok in results if ok) == 3
    assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 3
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_consume_oauth_state_is_one_time_even_under_concurrent_consumers(self, repo):
|
||||
import anyio
|
||||
|
||||
@@ -504,6 +504,27 @@ def test_connect_slack_returns_binding_command_and_persists_state(tmp_path):
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_binding_code_caps_pending_states_per_provider(tmp_path):
    """The sixth pending connect code for one provider is rejected with HTTP 429."""
    import anyio

    repo = anyio.run(_make_repo, tmp_path)
    try:
        app = _make_app(_enabled_connections_config(), repo, _channels_config())

        with TestClient(app) as client:
            responses = [client.post("/api/channels/slack/connect") for _ in range(6)]

        assert [response.status_code for response in responses[:5]] == [200, 200, 200, 200, 200]
        assert responses[5].status_code == 429
        assert "Too many pending channel connection codes" in responses[5].json()["detail"]

        async def count_states():
            return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="slack")

        assert anyio.run(count_states) == 5
    finally:
        # Close the repository even when an assertion above fails.
        anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_discord_returns_binding_command_and_persists_state(tmp_path):
|
||||
import anyio
|
||||
|
||||
|
||||
@@ -800,6 +800,126 @@ class TestChannelManager:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_dispatch_loop_dedupes_stable_provider_message_id(self, tmp_path):
    """A redelivered message_id is handled once; a different message_id is still handled."""
    from app.channels.manager import ChannelManager

    async def go():
        bus = MessageBus()
        store = ChannelStore(path=tmp_path / "store.json")
        manager = ChannelManager(bus=bus, store=store)
        manager._client = _make_mock_langgraph_client()
        outbound_received: list[OutboundMessage] = []

        async def capture_outbound(msg: OutboundMessage) -> None:
            outbound_received.append(msg)

        bus.subscribe_outbound(capture_outbound)

        def _slack_inbound(message_id: str) -> InboundMessage:
            # Distinct objects per publish, like a real provider redelivery.
            return InboundMessage(
                channel_name="slack",
                chat_id="C123",
                user_id="U123",
                text="sensitive prompt",
                topic_id="1710000000.000100",
                metadata={"team_id": "T123", "message_id": message_id},
            )

        await manager.start()
        try:
            # Same stable message_id delivered twice -> processed once.
            await bus.publish_inbound(_slack_inbound("1710000000.000200"))
            await bus.publish_inbound(_slack_inbound("1710000000.000200"))
            await _wait_for(lambda: manager._client.runs.wait.call_count == 1 and len(outbound_received) == 1)
            await asyncio.sleep(0.05)
            assert manager._client.threads.create.call_count == 1
            assert manager._client.runs.wait.call_count == 1
            assert len(outbound_received) == 1

            # Negative control: a *different* message_id must still be processed,
            # so an over-dedupe regression (dropping distinct messages) is caught.
            await bus.publish_inbound(_slack_inbound("1710000000.000999"))
            await _wait_for(lambda: manager._client.runs.wait.call_count == 2 and len(outbound_received) == 2)
            await asyncio.sleep(0.05)
        finally:
            # Stop the dispatch loop even when a wait or assertion above fails.
            await manager.stop()

        assert manager._client.runs.wait.call_count == 2
        assert len(outbound_received) == 2

    _run(go())
|
||||
|
||||
def test_inbound_dedupe_key_fails_closed_without_workspace(self):
    """Without a workspace identifier, skip dedupe instead of collapsing workspaces (willem #3)."""
    from app.channels.manager import ChannelManager

    with_workspace = InboundMessage(
        channel_name="slack",
        chat_id="C1",
        user_id="U1",
        text="x",
        metadata={"team_id": "T1", "message_id": "m1"},
    )
    assert ChannelManager._inbound_dedupe_key(with_workspace) == ("slack", "T1", "C1", "m1")

    without_workspace = InboundMessage(
        channel_name="slack",
        chat_id="C1",
        user_id="U1",
        text="x",
        metadata={"message_id": "m1"},
    )
    assert ChannelManager._inbound_dedupe_key(without_workspace) is None
|
||||
|
||||
def test_dispatch_loop_releases_dedupe_key_when_handling_fails(self, tmp_path):
    """A transient handling failure must not black-hole a provider redelivery (ShenAC #1)."""
    from app.channels.manager import ChannelManager

    async def go():
        bus = MessageBus()
        store = ChannelStore(path=tmp_path / "store.json")
        manager = ChannelManager(bus=bus, store=store)
        client = _make_mock_langgraph_client()
        attempts = {"n": 0}

        async def flaky_wait(*args, **kwargs):
            # Fail the first run only; every later run succeeds.
            attempts["n"] += 1
            if attempts["n"] == 1:
                raise RuntimeError("transient gateway 503")
            return {"messages": [{"type": "human", "content": "hi"}, {"type": "ai", "content": "recovered"}]}

        client.runs.wait = AsyncMock(side_effect=flaky_wait)
        manager._client = client

        outbound_received: list[OutboundMessage] = []

        async def capture_outbound(msg: OutboundMessage) -> None:
            outbound_received.append(msg)

        bus.subscribe_outbound(capture_outbound)

        inbound = InboundMessage(
            channel_name="slack",
            chat_id="C123",
            user_id="U123",
            text="hello",
            metadata={"team_id": "T123", "message_id": "m-1"},
        )

        await manager.start()
        try:
            # First delivery fails transiently; the dedupe key must be released.
            await bus.publish_inbound(inbound)
            await _wait_for(lambda: attempts["n"] == 1 and len(outbound_received) >= 1)

            # Provider redelivers the same message_id: it must be reprocessed, not dropped.
            await bus.publish_inbound(inbound)
            await _wait_for(lambda: attempts["n"] == 2)
            await asyncio.sleep(0.05)
        finally:
            # Stop the dispatch loop even when a wait above fails.
            await manager.stop()

        assert attempts["n"] == 2

    _run(go())
|
||||
|
||||
def test_handle_chat_outbound_preserves_inbound_metadata(self):
|
||||
"""DingTalk (and similar) need inbound metadata on outbound sends (e.g. sender_staff_id)."""
|
||||
from app.channels.manager import ChannelManager
|
||||
@@ -3752,7 +3872,7 @@ class TestWeComChannel:
|
||||
assert inbound.thread_ts == "msg-1"
|
||||
assert inbound.topic_id == "user-1"
|
||||
assert inbound.files == files
|
||||
assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single"}
|
||||
assert inbound.metadata == {"aibotid": "bot-1", "chattype": "single", "message_id": "msg-1"}
|
||||
assert channel._ws_frames["msg-1"] is frame
|
||||
assert channel._ws_stream_ids["msg-1"] == "stream-1"
|
||||
|
||||
|
||||
@@ -98,24 +98,13 @@ def test_slack_send_uses_connection_bot_token_when_connection_id_is_present():
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch):
|
||||
def test_slack_http_events_mode_is_rejected(monkeypatch, caplog):
|
||||
import anyio
|
||||
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
class FakeWebClient:
|
||||
def __init__(self, token: str) -> None:
|
||||
self.token = token
|
||||
self.messages: list[dict] = []
|
||||
|
||||
def auth_test(self):
|
||||
return {"user_id": "B-http"}
|
||||
|
||||
def chat_postMessage(self, **kwargs):
|
||||
self.messages.append(kwargs)
|
||||
|
||||
slack_sdk = ModuleType("slack_sdk")
|
||||
slack_sdk.WebClient = FakeWebClient
|
||||
slack_sdk.WebClient = object
|
||||
socket_mode = ModuleType("slack_sdk.socket_mode")
|
||||
socket_mode.SocketModeClient = object
|
||||
response = ModuleType("slack_sdk.socket_mode.response")
|
||||
@@ -129,26 +118,20 @@ def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch):
|
||||
bus=MessageBus(),
|
||||
config={
|
||||
"bot_token": "xoxb-operator",
|
||||
# Provide app_token too so the missing-token early return cannot
|
||||
# fire before the HTTP-mode guard — otherwise the state assertions
|
||||
# below would hold even if the guard were deleted.
|
||||
"app_token": "xapp-token",
|
||||
"event_delivery": "http",
|
||||
"connection_repo": MagicMock(),
|
||||
},
|
||||
)
|
||||
|
||||
with caplog.at_level("ERROR", logger="app.channels.slack"):
|
||||
await channel.start()
|
||||
assert channel._running is True
|
||||
assert channel._web_client is not None
|
||||
assert channel._web_client.token == "xoxb-operator"
|
||||
assert channel._bot_user_id == "B-http"
|
||||
|
||||
await channel._post_connection_reply("C123", "Slack connected to DeerFlow.", "1710000000.000100")
|
||||
|
||||
assert channel._web_client.messages == [
|
||||
{
|
||||
"channel": "C123",
|
||||
"text": "Slack connected to DeerFlow.",
|
||||
"thread_ts": "1710000000.000100",
|
||||
}
|
||||
]
|
||||
await channel.stop()
|
||||
assert channel._running is False
|
||||
assert channel._web_client is None
|
||||
assert "Slack HTTP Events mode is not supported" in caplog.text
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
Reference in New Issue
Block a user