mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-11 01:45:58 +00:00
226 lines
7.9 KiB
Python
226 lines
7.9 KiB
Python
"""Tests for per-user IM channel connection persistence."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from datetime import UTC, datetime, timedelta
|
|
|
|
import pytest
|
|
from sqlalchemy import select
|
|
|
|
from deerflow.persistence.channel_connections import (
|
|
ChannelConnectionRepository,
|
|
ChannelConnectionRow,
|
|
ChannelCredentialCipher,
|
|
ChannelCredentialRow,
|
|
ChannelWebhookDeliveryRow,
|
|
)
|
|
|
|
|
|
@pytest.fixture
|
|
async def repo(tmp_path):
|
|
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
|
|
|
url = f"sqlite+aiosqlite:///{tmp_path / 'channels.db'}"
|
|
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
|
try:
|
|
yield ChannelConnectionRepository(
|
|
get_session_factory(),
|
|
cipher=ChannelCredentialCipher.from_key("test-encryption-key"),
|
|
)
|
|
finally:
|
|
await close_engine()
|
|
|
|
|
|
class TestChannelConnectionRepository:
|
|
@pytest.mark.anyio
|
|
async def test_connections_are_listed_per_owner(self, repo):
|
|
alice = await repo.upsert_connection(
|
|
owner_user_id="alice",
|
|
provider="slack",
|
|
external_account_id="U-alice",
|
|
external_account_name="Alice",
|
|
workspace_id="T1",
|
|
workspace_name="Team One",
|
|
scopes=["chat:write"],
|
|
)
|
|
await repo.upsert_connection(
|
|
owner_user_id="bob",
|
|
provider="slack",
|
|
external_account_id="U-bob",
|
|
external_account_name="Bob",
|
|
workspace_id="T1",
|
|
workspace_name="Team One",
|
|
scopes=["chat:write"],
|
|
)
|
|
|
|
results = await repo.list_connections("alice")
|
|
|
|
assert [item["id"] for item in results] == [alice["id"]]
|
|
assert results[0]["owner_user_id"] == "alice"
|
|
assert results[0]["provider"] == "slack"
|
|
assert results[0]["scopes"] == ["chat:write"]
|
|
assert "encrypted_access_token" not in results[0]
|
|
|
|
@pytest.mark.anyio
|
|
async def test_upsert_connection_updates_existing_provider_identity(self, repo):
|
|
first = await repo.upsert_connection(
|
|
owner_user_id="alice",
|
|
provider="telegram",
|
|
external_account_id="42",
|
|
external_account_name="Alice",
|
|
workspace_id=None,
|
|
workspace_name=None,
|
|
status="pending",
|
|
)
|
|
second = await repo.upsert_connection(
|
|
owner_user_id="alice",
|
|
provider="telegram",
|
|
external_account_id="42",
|
|
external_account_name="Alice Telegram",
|
|
workspace_id=None,
|
|
workspace_name=None,
|
|
status="connected",
|
|
)
|
|
|
|
assert second["id"] == first["id"]
|
|
assert second["status"] == "connected"
|
|
assert second["external_account_name"] == "Alice Telegram"
|
|
assert len(await repo.list_connections("alice")) == 1
|
|
|
|
@pytest.mark.anyio
|
|
async def test_credentials_are_encrypted_at_rest_and_decrypted_by_repository(self, repo):
|
|
connection = await repo.upsert_connection(
|
|
owner_user_id="alice",
|
|
provider="slack",
|
|
external_account_id="U-alice",
|
|
workspace_id="T1",
|
|
)
|
|
expires_at = datetime.now(UTC) + timedelta(hours=1)
|
|
|
|
await repo.store_credentials(
|
|
connection["id"],
|
|
access_token="xoxb-secret-access-token",
|
|
refresh_token="secret-refresh-token",
|
|
token_type="Bearer",
|
|
expires_at=expires_at,
|
|
extra={"bot_user_id": "B123"},
|
|
)
|
|
|
|
async with repo.session_factory() as session:
|
|
row = (await session.execute(select(ChannelCredentialRow))).scalar_one()
|
|
assert row.encrypted_access_token is not None
|
|
assert "xoxb-secret-access-token" not in row.encrypted_access_token
|
|
assert "secret-refresh-token" not in (row.encrypted_refresh_token or "")
|
|
assert "B123" not in (row.encrypted_extra_json or "")
|
|
|
|
credentials = await repo.get_credentials(connection["id"])
|
|
|
|
assert credentials is not None
|
|
assert credentials["access_token"] == "xoxb-secret-access-token"
|
|
assert credentials["refresh_token"] == "secret-refresh-token"
|
|
assert credentials["token_type"] == "Bearer"
|
|
assert credentials["expires_at"] == expires_at
|
|
assert credentials["extra"] == {"bot_user_id": "B123"}
|
|
|
|
@pytest.mark.anyio
|
|
async def test_conversations_are_scoped_by_connection(self, repo):
|
|
alice = await repo.upsert_connection(
|
|
owner_user_id="alice",
|
|
provider="slack",
|
|
external_account_id="U-alice",
|
|
workspace_id="T1",
|
|
)
|
|
bob = await repo.upsert_connection(
|
|
owner_user_id="bob",
|
|
provider="slack",
|
|
external_account_id="U-bob",
|
|
workspace_id="T1",
|
|
)
|
|
|
|
await repo.set_thread_id(
|
|
connection_id=alice["id"],
|
|
owner_user_id="alice",
|
|
provider="slack",
|
|
external_conversation_id="C-shared",
|
|
external_topic_id="1710000000.000100",
|
|
thread_id="thread-alice",
|
|
)
|
|
await repo.set_thread_id(
|
|
connection_id=bob["id"],
|
|
owner_user_id="bob",
|
|
provider="slack",
|
|
external_conversation_id="C-shared",
|
|
external_topic_id="1710000000.000100",
|
|
thread_id="thread-bob",
|
|
)
|
|
|
|
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
|
|
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
|
|
|
|
@pytest.mark.anyio
|
|
async def test_disconnect_connection_revokes_owner_connection_and_removes_credentials(self, repo):
|
|
connection = await repo.upsert_connection(
|
|
owner_user_id="alice",
|
|
provider="telegram",
|
|
external_account_id="42",
|
|
)
|
|
await repo.store_credentials(connection["id"], access_token="secret-token")
|
|
|
|
disconnected = await repo.disconnect_connection(
|
|
connection_id=connection["id"],
|
|
owner_user_id="alice",
|
|
)
|
|
|
|
assert disconnected is True
|
|
async with repo.session_factory() as session:
|
|
connection_row = await session.get(ChannelConnectionRow, connection["id"])
|
|
credential_row = await session.get(ChannelCredentialRow, connection["id"])
|
|
assert connection_row is not None
|
|
assert connection_row.status == "revoked"
|
|
assert credential_row is None
|
|
assert (
|
|
await repo.find_connection_by_external_identity(
|
|
provider="telegram",
|
|
external_account_id="42",
|
|
)
|
|
is None
|
|
)
|
|
|
|
@pytest.mark.anyio
|
|
async def test_disconnect_connection_is_owner_scoped(self, repo):
|
|
connection = await repo.upsert_connection(
|
|
owner_user_id="alice",
|
|
provider="telegram",
|
|
external_account_id="42",
|
|
)
|
|
|
|
disconnected = await repo.disconnect_connection(
|
|
connection_id=connection["id"],
|
|
owner_user_id="bob",
|
|
)
|
|
|
|
assert disconnected is False
|
|
assert (await repo.list_connections("alice"))[0]["status"] == "connected"
|
|
|
|
@pytest.mark.anyio
|
|
async def test_record_webhook_delivery_returns_false_for_duplicate_delivery_id(self, repo):
|
|
first = await repo.record_webhook_delivery(
|
|
provider="slack",
|
|
delivery_id="Ev123",
|
|
payload_sha256="abc",
|
|
event_type="app_mention",
|
|
)
|
|
second = await repo.record_webhook_delivery(
|
|
provider="slack",
|
|
delivery_id="Ev123",
|
|
payload_sha256="abc",
|
|
event_type="app_mention",
|
|
)
|
|
|
|
assert first is True
|
|
assert second is False
|
|
async with repo.session_factory() as session:
|
|
rows = (await session.execute(select(ChannelWebhookDeliveryRow))).scalars().all()
|
|
assert len(rows) == 1
|
|
assert rows[0].event_type == "app_mention"
|