mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-18 13:46:02 +00:00
8c0830aea1
* fix(channels): add operational guardrails * make format * fix(channels): converge with #3582 to avoid merge-order conflicts Drop this PR's DingTalk INFO-log redaction and hand it to #3582, which already restructures that handler and will redact the same log there. This PR no longer touches dingtalk.py, so the two PRs can merge to main in any order without a conflict. For WeChat, drop the contested thread_ts priority reorder (review #3) and keep only what inbound dedupe needs: a server-stable message_id in the inbound metadata (message_id/msg_id, no client_id per review #6). This is a single added line inside the metadata dict, a region #3582 never touches, so it auto-merges regardless of order. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * fix(channels): address three correctness review findings 1. Connect-code cap was racy (willem #1): _create_state ran delete-expired, count, and insert as three separate transactions, so concurrent connect POSTs from one owner could each see count < cap and all insert past it. Add ChannelConnectionRepository.create_oauth_state_within_cap which does delete+count+insert in a single transaction serialized per (owner, provider) — Postgres via pg_advisory_xact_lock, SQLite via the write lock the leading DELETE takes — and have the router use it. 2. Inbound dedupe key fell back to "" workspace (willem #3): two workspaces delivering without team/guild/aibotid would collapse to the same key and dedupe each other's messages. _inbound_dedupe_key now fails closed (returns None) when no workspace identifier is present. 3. Dedupe key was recorded on receipt and never released on failure (ShenAC #1): a transient error (DB blip, Gateway 503) left the key in place for the full TTL, so a provider redelivery of the same message_id — exactly the retry dedupe should absorb — was silently dropped. 
_handle_message now releases the key in the unexpected-exception branch so redelivery can recover, while keeping record-on-receipt so retries during handling are still deduped. Tests: repo cap enforcement incl. concurrent-issuance non-leak; dedupe fail-closed; dedupe key release-on-failure redelivery recovery. Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * fix(channels): address cleanup/efficiency and test review findings Efficiency / cleanup: - Dedupe key set drops client-generated ids (client_msg_id, client_id); keep only server-stable event_id/message_id/msg_id, which a provider's own redelivery preserves (ShenAC #6). Every provider already emits message_id. - TTL/overflow pruning of _recent_inbound_events is now O(k): switch to an OrderedDict and popitem(last=False) from the front instead of scanning all 4096 entries on every inbound (willem #4). - Log "received inbound" only after the dedupe check so a provider retrying N times no longer logs N accepts; document that manager dedupe covers the agent run/final answer, not provider ack side-effects (willem #5, ShenAC #2). - Slack drops the redundant `team_id or event.get("team")` fallback the caller already resolved (willem #6). - create_oauth_state_within_cap prunes only this owner/provider's expired codes instead of a global DELETE on every connect POST; global cleanup still runs on consume_oauth_state (willem #7). Tests: - Dedupe test uses tmp_path instead of a leaked mkdtemp, uses distinct objects per publish, and adds a negative control: a different message_id is still processed, catching over-dedupe regressions (willem #8, ShenAC #4). - Slack HTTP-mode rejection test supplies app_token so the missing-token early return can't mask the guard, giving the state assertions teeth (ShenAC #3). - count_oauth_states test pins that the active row survives, not just the count (ShenAC #5). 
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> * make format --------- Co-authored-by: Claude Opus 4.8 <noreply@anthropic.com>
403 lines
16 KiB
Python
"""Tests for per-user IM channel connection persistence."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import pytest
|
|
from sqlalchemy import select
|
|
|
|
from deerflow.persistence.channel_connections import (
|
|
ChannelConnectionRepository,
|
|
ChannelConnectionRow,
|
|
ChannelCredentialCipher,
|
|
ChannelCredentialRow,
|
|
ChannelOAuthStateRow,
|
|
)
|
|
|
|
|
|
@pytest.fixture
async def repo(tmp_path):
    """Provide a ChannelConnectionRepository backed by a throwaway SQLite file.

    The database lives under pytest's per-test ``tmp_path``. The engine is set up
    and torn down through module-level helpers, so ``close_engine`` is always
    awaited on teardown, even if building the repository fails.
    """
    from deerflow.persistence.engine import close_engine, get_session_factory, init_engine

    db_file = tmp_path / "channels.db"
    await init_engine("sqlite", url=f"sqlite+aiosqlite:///{db_file}", sqlite_dir=str(tmp_path))
    try:
        session_factory = get_session_factory()
        cipher = ChannelCredentialCipher.from_key("test-encryption-key")
        yield ChannelConnectionRepository(session_factory, cipher=cipher)
    finally:
        await close_engine()
|
|
|
|
|
|
class TestChannelConnectionRepository:
    """Behavioural tests for ChannelConnectionRepository against a SQLite database.

    Every test receives a repository from the ``repo`` fixture, which is backed
    by a database file under the test's own ``tmp_path``.
    """

    @pytest.mark.anyio
    async def test_connections_are_listed_per_owner(self, repo):
        """list_connections returns only the caller's rows and no secret columns."""
        alice = await repo.upsert_connection(
            owner_user_id="alice",
            provider="slack",
            external_account_id="U-alice",
            external_account_name="Alice",
            workspace_id="T1",
            workspace_name="Team One",
            scopes=["chat:write"],
        )
        await repo.upsert_connection(
            owner_user_id="bob",
            provider="slack",
            external_account_id="U-bob",
            external_account_name="Bob",
            workspace_id="T1",
            workspace_name="Team One",
            scopes=["chat:write"],
        )

        results = await repo.list_connections("alice")

        assert [item["id"] for item in results] == [alice["id"]]
        assert results[0]["owner_user_id"] == "alice"
        assert results[0]["provider"] == "slack"
        assert results[0]["scopes"] == ["chat:write"]
        assert "encrypted_access_token" not in results[0]

    @pytest.mark.anyio
    async def test_upsert_connection_updates_existing_provider_identity(self, repo):
        """Re-upserting the same (owner, provider, account) updates in place rather than duplicating."""
        first = await repo.upsert_connection(
            owner_user_id="alice",
            provider="telegram",
            external_account_id="42",
            external_account_name="Alice",
            workspace_id=None,
            workspace_name=None,
            status="pending",
        )
        second = await repo.upsert_connection(
            owner_user_id="alice",
            provider="telegram",
            external_account_id="42",
            external_account_name="Alice Telegram",
            workspace_id=None,
            workspace_name=None,
            status="connected",
        )

        assert second["id"] == first["id"]
        assert second["status"] == "connected"
        assert second["external_account_name"] == "Alice Telegram"
        assert len(await repo.list_connections("alice")) == 1

    @pytest.mark.anyio
    async def test_credentials_are_encrypted_at_rest_and_decrypted_by_repository(self, repo):
        """Stored secrets never appear in plaintext in the row, but round-trip through the repository."""
        connection = await repo.upsert_connection(
            owner_user_id="alice",
            provider="slack",
            external_account_id="U-alice",
            workspace_id="T1",
        )
        expires_at = datetime.now(UTC) + timedelta(hours=1)

        await repo.store_credentials(
            connection["id"],
            access_token="xoxb-secret-access-token",
            refresh_token="secret-refresh-token",
            token_type="Bearer",
            expires_at=expires_at,
            extra={"bot_user_id": "B123"},
        )

        async with repo.session_factory() as session:
            row = (await session.execute(select(ChannelCredentialRow))).scalar_one()
            assert row.encrypted_access_token is not None
            assert "xoxb-secret-access-token" not in row.encrypted_access_token
            # Require the encrypted columns to be populated. Without this, the
            # "not in" checks below would pass vacuously on NULL columns, even if
            # the secret had been persisted somewhere else.
            assert row.encrypted_refresh_token is not None
            assert "secret-refresh-token" not in row.encrypted_refresh_token
            assert row.encrypted_extra_json is not None
            assert "B123" not in row.encrypted_extra_json

        credentials = await repo.get_credentials(connection["id"])

        assert credentials is not None
        assert credentials["access_token"] == "xoxb-secret-access-token"
        assert credentials["refresh_token"] == "secret-refresh-token"
        assert credentials["token_type"] == "Bearer"
        assert credentials["expires_at"] == expires_at
        assert credentials["extra"] == {"bot_user_id": "B123"}

    @pytest.mark.anyio
    async def test_get_credentials_returns_none_when_decryption_fails(self, repo, caplog):
        """A repository with the wrong key yields None and logs a warning instead of raising."""
        connection = await repo.upsert_connection(
            owner_user_id="alice",
            provider="slack",
            external_account_id="U-alice",
            workspace_id="T1",
        )
        await repo.store_credentials(connection["id"], access_token="xoxb-secret-access-token")
        wrong_key_repo = ChannelConnectionRepository(
            repo.session_factory,
            cipher=ChannelCredentialCipher.from_key("wrong-encryption-key"),
        )

        with caplog.at_level(logging.WARNING, logger="deerflow.persistence.channel_connections.sql"):
            credentials = await wrong_key_repo.get_credentials(connection["id"])

        assert credentials is None
        # caplog.messages uses getMessage(), so it does not rely on a formatter
        # having populated record.message.
        assert any("Unable to decrypt channel connection credentials" in message for message in caplog.messages)

    @pytest.mark.anyio
    async def test_conversations_are_scoped_by_connection(self, repo):
        """The same external conversation/topic maps to different threads per connection."""
        alice = await repo.upsert_connection(
            owner_user_id="alice",
            provider="slack",
            external_account_id="U-alice",
            workspace_id="T1",
        )
        bob = await repo.upsert_connection(
            owner_user_id="bob",
            provider="slack",
            external_account_id="U-bob",
            workspace_id="T1",
        )

        await repo.set_thread_id(
            connection_id=alice["id"],
            owner_user_id="alice",
            provider="slack",
            external_conversation_id="C-shared",
            external_topic_id="1710000000.000100",
            thread_id="thread-alice",
        )
        await repo.set_thread_id(
            connection_id=bob["id"],
            owner_user_id="bob",
            provider="slack",
            external_conversation_id="C-shared",
            external_topic_id="1710000000.000100",
            thread_id="thread-bob",
        )

        assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
        assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"

    @pytest.mark.anyio
    async def test_disconnect_connection_revokes_owner_connection_and_removes_credentials(self, repo):
        """Disconnect marks the row revoked, drops credentials, and hides it from identity lookup."""
        connection = await repo.upsert_connection(
            owner_user_id="alice",
            provider="telegram",
            external_account_id="42",
        )
        await repo.store_credentials(connection["id"], access_token="secret-token")

        disconnected = await repo.disconnect_connection(
            connection_id=connection["id"],
            owner_user_id="alice",
        )

        assert disconnected is True
        async with repo.session_factory() as session:
            connection_row = await session.get(ChannelConnectionRow, connection["id"])
            credential_row = await session.get(ChannelCredentialRow, connection["id"])
            assert connection_row is not None
            assert connection_row.status == "revoked"
            assert credential_row is None
        assert (
            await repo.find_connection_by_external_identity(
                provider="telegram",
                external_account_id="42",
            )
            is None
        )

    @pytest.mark.anyio
    async def test_disconnect_connection_is_owner_scoped(self, repo):
        """A different user cannot disconnect someone else's connection."""
        connection = await repo.upsert_connection(
            owner_user_id="alice",
            provider="telegram",
            external_account_id="42",
        )

        disconnected = await repo.disconnect_connection(
            connection_id=connection["id"],
            owner_user_id="bob",
        )

        assert disconnected is False
        assert (await repo.list_connections("alice"))[0]["status"] == "connected"

    @pytest.mark.anyio
    async def test_consume_oauth_state_deletes_expired_states(self, repo):
        """Consuming an expired state yields None and prunes expired rows, leaving active ones."""
        now = datetime.now(UTC)
        await repo.create_oauth_state(
            owner_user_id="alice",
            provider="slack",
            state="expired-state",
            expires_at=now - timedelta(minutes=1),
        )
        await repo.create_oauth_state(
            owner_user_id="alice",
            provider="slack",
            state="active-state",
            expires_at=now + timedelta(minutes=5),
        )

        consumed = await repo.consume_oauth_state(provider="slack", state="expired-state", now=now)

        assert consumed is None
        async with repo.session_factory() as session:
            states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
            assert [state.state_hash for state in states] == [repo.hash_state("active-state")]

    @pytest.mark.anyio
    async def test_count_oauth_states_active_only_and_delete_expired(self, repo):
        """active_only counting ignores expired rows, and delete_expired removes exactly those."""
        now = datetime.now(UTC)
        await repo.create_oauth_state(
            owner_user_id="alice",
            provider="slack",
            state="expired-state",
            expires_at=now - timedelta(minutes=1),
        )
        await repo.create_oauth_state(
            owner_user_id="alice",
            provider="slack",
            state="active-state",
            expires_at=now + timedelta(minutes=5),
        )

        assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 1
        assert await repo.delete_expired_oauth_states(now=now) == 1
        assert await repo.count_oauth_states(owner_user_id="alice", provider="slack") == 1
        # Pin that the surviving row is the active one. An inverted expiry
        # predicate would delete the active row, still return 1, and pass above.
        async with repo.session_factory() as session:
            survivors = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
            assert [row.state_hash for row in survivors] == [repo.hash_state("active-state")]

    @pytest.mark.anyio
    async def test_create_oauth_state_within_cap_enforces_pending_cap(self, repo):
        """Issuance beyond ``max_pending`` active codes is refused, and the cap is per owner."""
        now = datetime.now(UTC)
        expires = now + timedelta(minutes=5)

        for i in range(3):
            inserted = await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state=f"code-{i}", expires_at=expires, max_pending=3, now=now)
            assert inserted is True

        # Cap reached: the next issuance is rejected and nothing is inserted.
        assert await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state="code-over", expires_at=expires, max_pending=3, now=now) is False
        assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 3

        # A different owner has an independent cap. Bob can still issue while
        # alice is full, and alice's pending codes are left untouched. Pruning of
        # expired rows is covered by test_create_oauth_state_within_cap_ignores_expired_rows.
        assert await repo.create_oauth_state_within_cap(owner_user_id="bob", provider="slack", state="bob-1", expires_at=expires, max_pending=3, now=now) is True
        assert await repo.count_oauth_states(owner_user_id="bob", provider="slack", active_only=True, now=now) == 1
        assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 3

    @pytest.mark.anyio
    async def test_create_oauth_state_within_cap_ignores_expired_rows(self, repo):
        """Expired codes do not count toward the pending cap."""
        now = datetime.now(UTC)
        # Three already-expired rows must not count against the cap.
        for i in range(3):
            await repo.create_oauth_state(owner_user_id="alice", provider="slack", state=f"old-{i}", expires_at=now - timedelta(minutes=1))

        inserted = await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state="fresh", expires_at=now + timedelta(minutes=5), max_pending=3, now=now)
        assert inserted is True
        assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 1

    @pytest.mark.anyio
    async def test_create_oauth_state_within_cap_does_not_leak_under_concurrency(self, repo):
        """Concurrent issuance for one owner cannot push past the cap (willem #1)."""
        import anyio

        now = datetime.now(UTC)
        expires = now + timedelta(minutes=5)
        results: list[bool] = []

        async def issue(state: str) -> None:
            results.append(await repo.create_oauth_state_within_cap(owner_user_id="alice", provider="slack", state=state, expires_at=expires, max_pending=3, now=now))

        async with anyio.create_task_group() as tg:
            for i in range(8):
                tg.start_soon(issue, f"code-{i}")

        assert sum(1 for ok in results if ok) == 3
        assert await repo.count_oauth_states(owner_user_id="alice", provider="slack", active_only=True, now=now) == 3

    @pytest.mark.anyio
    async def test_consume_oauth_state_is_one_time_even_under_concurrent_consumers(self, repo):
        """Two racing consumers of the same state: exactly one wins."""
        import anyio

        now = datetime.now(UTC)
        await repo.create_oauth_state(
            owner_user_id="alice",
            provider="slack",
            state="bind-once",
            expires_at=now + timedelta(minutes=5),
        )

        results: list[dict | None] = []

        async def consume():
            results.append(await repo.consume_oauth_state(provider="slack", state="bind-once", now=now))

        async with anyio.create_task_group() as tg:
            tg.start_soon(consume)
            tg.start_soon(consume)

        consumed = [result for result in results if result is not None]
        assert len(consumed) == 1
        assert consumed[0]["owner_user_id"] == "alice"

    @pytest.mark.anyio
    async def test_upsert_connection_retries_as_update_when_concurrent_insert_wins(self, repo):
        """A losing concurrent INSERT retries as an UPDATE instead of raising IntegrityError."""
        first = await repo.upsert_connection(
            owner_user_id="alice",
            provider="slack",
            external_account_id="U-race",
            workspace_id="T-race",
            status="pending",
        )

        real_factory = repo.session_factory

        class _EmptyResult:
            @staticmethod
            def scalar_one_or_none():
                return None

        class MissFirstSelectSession:
            """Make the initial identity SELECT miss, as if a concurrent writer inserted after it."""

            def __init__(self, session):
                self._session = session
                self._missed = False

            def __getattr__(self, name):
                return getattr(self._session, name)

            async def execute(self, *args, **kwargs):
                # Run the real statement, but hide the result of the first call
                # so the repository believes the row does not exist yet.
                result = await self._session.execute(*args, **kwargs)
                if not self._missed:
                    self._missed = True
                    return _EmptyResult()
                return result

            async def __aenter__(self):
                await self._session.__aenter__()
                return self

            async def __aexit__(self, *args):
                return await self._session.__aexit__(*args)

        repo.session_factory = lambda: MissFirstSelectSession(real_factory())
        try:
            second = await repo.upsert_connection(
                owner_user_id="alice",
                provider="slack",
                external_account_id="U-race",
                workspace_id="T-race",
                status="connected",
            )
        finally:
            repo.session_factory = real_factory

        assert second["id"] == first["id"]
        assert second["status"] == "connected"
        connections = await repo.list_connections("alice")
        assert len(connections) == 1
        assert connections[0]["status"] == "connected"
|