mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-06-13 19:06:01 +00:00
feat(im): Add user-owned IM channel connections (#3487)
* Add user-owned IM channel connections * Fix dev startup and channel connect popup * Use async channel connect flow * Harden dev service daemon startup * Support local IM channel connections * Align IM connections with local channels * Fix safe user id digest algorithm * Address Copilot IM channel feedback * Address IM channel review comments * Support all integrated IM channel connections * Format additional channel connection tests * Keep unavailable channel connect buttons clickable * Fix IM channel provider icons * Add runtime setup for enabled IM channels * Guard global shortcut key handling * Keep configured IM channels editable * Avoid password autofill for channel secrets * Make channel threads visible to connection owners * Persist IM runtime config locally * Allow disconnecting runtime IM channels * Route no-auth channel sessions to local user * Use default user for auth-disabled local mode * Show IM channel source on threads * Prefill IM channel runtime config * Reflect IM channel runtime health * Ignore Feishu message read events * Ignore Feishu non-content message events * Let setup wizard enable IM channels * Fix frontend formatting after merge * Stabilize backend tests without local config * Isolate channel runtime config tests * Address channel connection review comments * Use sha256 user buckets with legacy migration * Ensure runtime IM channels are ready after restart * Persist disconnected IM channel state * Address channel connection review comments * Address channel connection review findings Frontend connect flow: - Open the runtime-config dialog only when a provider still needs credentials; configured providers go straight to the connect flow, so the binding-code/deep-link path is reachable from the UI again. - After saving credentials, continue into the connect flow when a user binding is still required (multi-user mode) instead of stopping at a "Connected" toast. - Extract shared provider-state helpers to core/channels/provider-state and add unit + e2e coverage for the direct-connect and configure-then-connect paths. Provider status semantics: - Report connection_status from the user's newest connection row; with no binding it is not_connected, except in auth-disabled local mode where a configured running channel is effectively connected. Concurrency and event-loop correctness: - Offload ChannelRuntimeConfigStore construction and writes, channel service construction, and Slack connection replies to threads; add a tests/blocking_io/ anchor for the runtime-config handlers. - Consume binding codes with a conditional UPDATE so a code can only be used once under concurrent workers; retry upsert_connection as an update when a concurrent insert wins the unique constraint. - Serialize ensure_channel_ready per channel so concurrent provider polls cannot double-start a channel worker. Config and migration hardening: - Stop mutating the get_app_config()-cached Telegram provider config; the runtime store now owns the UI-entered bot username. - Register channel_connections in STARTUP_ONLY_FIELDS with the standardized startup-only Field description. - Match the legacy unsafe-id bucket by recomputing its exact SHA-1 name so another user's same-prefix bucket can never be migrated. - Remove the unused Telegram process_webhook_update path and document src/core/channels in the frontend docs. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> * Address PR review comments on authz scoping and channel runtime Security (review feedback from ShenAC-SAC): - Scope internal-token callers to the connection owner carried in X-DeerFlow-Owner-User-Id instead of bypassing owner checks outright, in both require_permission(owner_check=True) and the stateless run endpoints. Internal callers keep access to their own and shared/legacy threads, and may claim a default-owned channel thread for its real owner, but a leaked internal token no longer grants cross-user thread access. - Require admin privileges for POST/DELETE /api/channels/{provider}/ runtime-config: runtime credentials and channel workers are instance-wide shared state (same model as the MCP config API). Read-only provider listing stays available to all users. Performance (review feedback from willem-bd): - Skip the redundant thread channel-metadata PATCH after the first successful backfill per thread. - Reuse the per-connection Slack WebClient until its token changes instead of constructing one per outbound message. - Reconcile channel readiness for all providers concurrently in GET /api/channels/providers. Also resolve the code-quality unused-import flag in the blocking-io anchor by pre-importing the channel service via importlib. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> * Fix prettier formatting in provider-state test Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> * Reconcile UI runtime channel config with config reload on restart Main now reloads a channel's config.yaml entry on restart_channel() (#3514, issue #3497). Adapt the user-owned connection flow to coexist: - configure_channel() restarts with reload_config=False — the caller just supplied the authoritative config (browser-entered credentials that are never written to config.yaml), so a file reload must not clobber it with the stale on-disk entry. - _load_channel_config() re-applies the UI runtime-store overlay used at startup, so an operator-triggered restart keeps browser-entered credentials for channels without a config.yaml entry and does not resurrect a channel disconnected from the UI. - Offload the reload's disk IO (config.yaml + runtime store) with asyncio.to_thread, matching the blocking-IO policy on this branch. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com> --------- Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
This commit is contained in:
@@ -0,0 +1,106 @@
|
||||
"""Regression anchors: channel runtime-config handlers must not block the event loop.
|
||||
|
||||
``configure_channel_provider_runtime`` and ``disconnect_channel_provider_runtime``
|
||||
persist UI-entered channel credentials through ``ChannelRuntimeConfigStore``,
|
||||
whose construction reads its JSON file and whose setters rewrite it
|
||||
(``json.dump`` + ``Path.replace`` + ``chmod``). The handlers offload both via
|
||||
``asyncio.to_thread``; if that regresses back onto the event loop, the strict
|
||||
Blockbuster gate raises ``BlockingError`` and these tests fail.
|
||||
|
||||
The handlers are invoked directly with a minimal Starlette ``Request`` so the
|
||||
surface under test is exactly the router's own IO, mirroring
|
||||
``test_agents_router``. Test-side seeding/inspection is offloaded with
|
||||
``asyncio.to_thread``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import importlib
|
||||
from types import SimpleNamespace
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
|
||||
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||
from app.gateway.routers.channel_connections import (
|
||||
ChannelRuntimeConfigRequest,
|
||||
configure_channel_provider_runtime,
|
||||
disconnect_channel_provider_runtime,
|
||||
)
|
||||
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
# Pre-import: the handlers import this module lazily; the import's file IO
|
||||
# must happen at collection time, not on the event loop under the gate.
|
||||
importlib.import_module("app.channels.service")
|
||||
|
||||
pytestmark = pytest.mark.asyncio
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_app_config():
|
||||
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
|
||||
yield
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def _make_request(tmp_path) -> Request:
|
||||
app = FastAPI()
|
||||
app.state.channel_connections_config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
}
|
||||
)
|
||||
app.state.channels_config = {}
|
||||
app.state.channel_connection_repo = _FakeRepo()
|
||||
store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||
app.state.channel_runtime_config_store = store
|
||||
user = SimpleNamespace(id=UUID("11111111-2222-3333-4444-555555555555"), system_role="admin")
|
||||
return Request({"type": "http", "app": app, "headers": [], "state": {"user": user}})
|
||||
|
||||
|
||||
class _FakeRepo:
|
||||
async def list_connections(self, owner_user_id):
|
||||
return []
|
||||
|
||||
|
||||
async def test_configure_runtime_channel_does_not_block_event_loop(tmp_path) -> None:
|
||||
request = await asyncio.to_thread(_make_request, tmp_path)
|
||||
|
||||
response = await configure_channel_provider_runtime(
|
||||
"slack",
|
||||
ChannelRuntimeConfigRequest(values={"bot_token": "xoxb-ui", "app_token": "xapp-ui"}),
|
||||
request,
|
||||
)
|
||||
|
||||
assert response.provider == "slack"
|
||||
store = request.app.state.channel_runtime_config_store
|
||||
assert await asyncio.to_thread(store.get_provider_config, "slack") == {
|
||||
"enabled": True,
|
||||
"bot_token": "xoxb-ui",
|
||||
"app_token": "xapp-ui",
|
||||
}
|
||||
|
||||
|
||||
async def test_disconnect_runtime_channel_does_not_block_event_loop(tmp_path) -> None:
|
||||
request = await asyncio.to_thread(_make_request, tmp_path)
|
||||
store = request.app.state.channel_runtime_config_store
|
||||
await asyncio.to_thread(
|
||||
store.set_provider_config,
|
||||
"slack",
|
||||
{"enabled": True, "bot_token": "xoxb-ui", "app_token": "xapp-ui"},
|
||||
)
|
||||
request.app.state.channels_config = {
|
||||
"slack": {"enabled": True, "bot_token": "xoxb-ui", "app_token": "xapp-ui"},
|
||||
}
|
||||
|
||||
response = await disconnect_channel_provider_runtime("slack", request)
|
||||
|
||||
assert response.provider == "slack"
|
||||
assert await asyncio.to_thread(store.get_provider_config, "slack") == {
|
||||
"enabled": False,
|
||||
"_runtime_disabled": True,
|
||||
}
|
||||
@@ -0,0 +1,251 @@
|
||||
"""Connection binding tests for browser-connectable IM channels beyond Telegram/Slack/Discord."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.channels.message_bus import InboundMessage, MessageBus
|
||||
|
||||
|
||||
async def _make_repo(tmp_path, name: str):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / f'{name}.db'}", sqlite_dir=str(tmp_path))
|
||||
return ChannelConnectionRepository(get_session_factory())
|
||||
|
||||
|
||||
async def _seed_state(repo, provider: str, state: str, owner_user_id: str = "deerflow-user-1") -> None:
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id=owner_user_id,
|
||||
provider=provider,
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
|
||||
|
||||
def test_feishu_connect_command_binds_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.feishu import FeishuChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path, "feishu")
|
||||
state = "feishu-bind-code"
|
||||
await _seed_state(repo, "feishu", state)
|
||||
channel = FeishuChannel(
|
||||
bus=MessageBus(),
|
||||
config={"app_id": "app", "app_secret": "secret", "connection_repo": repo},
|
||||
)
|
||||
channel._reply_card = AsyncMock()
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(
|
||||
message_id="om-message-1",
|
||||
chat_id="oc-chat-1",
|
||||
user_id="ou-user-1",
|
||||
code=state,
|
||||
)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "feishu"
|
||||
assert connections[0]["external_account_id"] == "ou-user-1"
|
||||
assert connections[0]["workspace_id"] == "oc-chat-1"
|
||||
channel._reply_card.assert_awaited_once_with("om-message-1", "Feishu connected to DeerFlow.")
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_dingtalk_connect_command_binds_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path, "dingtalk")
|
||||
state = "dingtalk-bind-code"
|
||||
await _seed_state(repo, "dingtalk", state)
|
||||
channel = DingTalkChannel(
|
||||
bus=MessageBus(),
|
||||
config={"client_id": "client", "client_secret": "secret", "connection_repo": repo},
|
||||
)
|
||||
channel._send_connection_reply = AsyncMock()
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(
|
||||
conversation_type=_CONVERSATION_TYPE_GROUP,
|
||||
sender_staff_id="staff-user-1",
|
||||
sender_nick="Alice",
|
||||
conversation_id="cid-group-1",
|
||||
code=state,
|
||||
)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "dingtalk"
|
||||
assert connections[0]["external_account_id"] == "staff-user-1"
|
||||
assert connections[0]["external_account_name"] == "Alice"
|
||||
assert connections[0]["workspace_id"] == "cid-group-1"
|
||||
channel._send_connection_reply.assert_awaited_once()
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_wechat_connect_command_binds_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.wechat import WechatChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path, "wechat")
|
||||
state = "wechat-bind-code"
|
||||
await _seed_state(repo, "wechat", state)
|
||||
channel = WechatChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_token": "token", "connection_repo": repo},
|
||||
)
|
||||
channel._send_connection_reply = AsyncMock()
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(
|
||||
chat_id="wx-user-1",
|
||||
context_token="ctx-1",
|
||||
code=state,
|
||||
)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "wechat"
|
||||
assert connections[0]["external_account_id"] == "wx-user-1"
|
||||
assert connections[0]["workspace_id"] == "wx-user-1"
|
||||
channel._send_connection_reply.assert_awaited_once_with("wx-user-1", "ctx-1", "WeChat connected to DeerFlow.")
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_wecom_connect_command_binds_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.wecom import WeComChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path, "wecom")
|
||||
state = "wecom-bind-code"
|
||||
await _seed_state(repo, "wecom", state)
|
||||
channel = WeComChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_id": "bot", "bot_secret": "secret", "connection_repo": repo},
|
||||
)
|
||||
channel._ws_client = MagicMock()
|
||||
channel._ws_client.reply = AsyncMock()
|
||||
frame = {"body": {"aibotid": "bot-1", "chattype": "single"}}
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(
|
||||
frame=frame,
|
||||
user_id="wecom-user-1",
|
||||
code=state,
|
||||
)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "wecom"
|
||||
assert connections[0]["external_account_id"] == "wecom-user-1"
|
||||
assert connections[0]["workspace_id"] == "bot-1"
|
||||
channel._ws_client.reply.assert_awaited_once_with(frame, {"msgtype": "text", "text": {"content": "WeCom connected to DeerFlow."}})
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_additional_channels_attach_owner_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.dingtalk import _CONVERSATION_TYPE_GROUP, DingTalkChannel
|
||||
from app.channels.feishu import FeishuChannel
|
||||
from app.channels.wechat import WechatChannel
|
||||
from app.channels.wecom import WeComChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path, "additional-identity")
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="feishu",
|
||||
external_account_id="ou-user-1",
|
||||
workspace_id="oc-chat-1",
|
||||
)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="dingtalk",
|
||||
external_account_id="staff-user-1",
|
||||
workspace_id="cid-group-1",
|
||||
)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="wechat",
|
||||
external_account_id="wx-user-1",
|
||||
workspace_id="wx-user-1",
|
||||
)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="wecom",
|
||||
external_account_id="wecom-user-1",
|
||||
workspace_id="bot-1",
|
||||
)
|
||||
|
||||
cases = [
|
||||
(
|
||||
FeishuChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||
InboundMessage(channel_name="feishu", chat_id="oc-chat-1", user_id="ou-user-1", text="hello"),
|
||||
),
|
||||
(
|
||||
DingTalkChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||
InboundMessage(
|
||||
channel_name="dingtalk",
|
||||
chat_id="cid-group-1",
|
||||
user_id="staff-user-1",
|
||||
text="hello",
|
||||
metadata={
|
||||
"conversation_type": _CONVERSATION_TYPE_GROUP,
|
||||
"conversation_id": "cid-group-1",
|
||||
},
|
||||
),
|
||||
),
|
||||
(
|
||||
WechatChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||
InboundMessage(channel_name="wechat", chat_id="wx-user-1", user_id="wx-user-1", text="hello"),
|
||||
),
|
||||
(
|
||||
WeComChannel(bus=MessageBus(), config={"connection_repo": repo}),
|
||||
InboundMessage(
|
||||
channel_name="wecom",
|
||||
chat_id="wecom-user-1",
|
||||
user_id="wecom-user-1",
|
||||
text="hello",
|
||||
metadata={"aibotid": "bot-1"},
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
for channel, inbound in cases:
|
||||
attached = await channel._attach_connection_identity(inbound)
|
||||
assert attached.owner_user_id == "deerflow-user-1"
|
||||
assert attached.connection_id
|
||||
assert (
|
||||
attached.workspace_id
|
||||
== {
|
||||
"feishu": "oc-chat-1",
|
||||
"dingtalk": "cid-group-1",
|
||||
"wechat": "wx-user-1",
|
||||
"wecom": "bot-1",
|
||||
}[channel.name]
|
||||
)
|
||||
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
@@ -280,6 +280,74 @@ def test_require_permission_denies_wrong_permission():
|
||||
assert "Permission denied" in response.json()["detail"]
|
||||
|
||||
|
||||
def _make_internal_owner_check_app():
|
||||
"""App with an owner_check route and a thread owned by ``alice``."""
|
||||
import asyncio
|
||||
|
||||
from fastapi import Request
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
|
||||
app = FastAPI()
|
||||
thread_store = MemoryThreadMetaStore(InMemoryStore())
|
||||
asyncio.run(thread_store.create("alice-thread", user_id="alice"))
|
||||
app.state.thread_store = thread_store
|
||||
|
||||
@app.get("/threads/{thread_id}")
|
||||
@require_permission("threads", "read", owner_check=True)
|
||||
async def endpoint(thread_id: str, request: Request):
|
||||
return {"ok": True}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def _internal_auth_context() -> AuthContext:
|
||||
from types import SimpleNamespace
|
||||
|
||||
from app.gateway.internal_auth import INTERNAL_SYSTEM_ROLE
|
||||
|
||||
user = SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)
|
||||
return AuthContext(user=user, permissions=[Permissions.THREADS_READ])
|
||||
|
||||
|
||||
def test_require_permission_internal_role_scoped_by_owner_header():
|
||||
"""An internal caller acting for the thread owner passes the owner check."""
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
|
||||
app = _make_internal_owner_check_app()
|
||||
with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()):
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
"/threads/alice-thread",
|
||||
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "alice"},
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
|
||||
def test_require_permission_internal_role_denied_for_other_owner():
|
||||
"""The internal token must not grant access to another user's thread."""
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
|
||||
app = _make_internal_owner_check_app()
|
||||
with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()):
|
||||
with TestClient(app) as client:
|
||||
response = client.get(
|
||||
"/threads/alice-thread",
|
||||
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "mallory"},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
def test_require_permission_internal_role_without_header_is_scoped_to_internal_user():
|
||||
"""With no owner header, internal callers are scoped like before the bypass."""
|
||||
app = _make_internal_owner_check_app()
|
||||
with patch("app.gateway.authz._authenticate", return_value=_internal_auth_context()):
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/threads/alice-thread")
|
||||
assert response.status_code == 404
|
||||
|
||||
|
||||
# ── Weak JWT secret warning ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
|
||||
@@ -39,6 +39,8 @@ def test_public_paths(path: str):
|
||||
"/api/threads/123/uploads",
|
||||
"/api/agents",
|
||||
"/api/channels",
|
||||
"/api/channels/providers",
|
||||
"/api/channels/slack/connect",
|
||||
"/api/runs/stream",
|
||||
"/api/threads/123/runs",
|
||||
"/api/v1/auth/me",
|
||||
@@ -183,7 +185,7 @@ def _make_auth_csrf_app():
|
||||
|
||||
@pytest.fixture
|
||||
def client(monkeypatch):
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "")
|
||||
return TestClient(_make_app())
|
||||
|
||||
|
||||
@@ -221,7 +223,7 @@ def test_auth_disabled_allows_protected_path_without_cookie(monkeypatch):
|
||||
assert res.json() == {"models": []}
|
||||
|
||||
|
||||
def test_auth_disabled_stamps_e2e_admin_user_without_cookie(monkeypatch):
|
||||
def test_auth_disabled_stamps_default_admin_user_without_cookie(monkeypatch):
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
client = TestClient(_make_app())
|
||||
|
||||
@@ -229,10 +231,10 @@ def test_auth_disabled_stamps_e2e_admin_user_without_cookie(monkeypatch):
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {
|
||||
"id": "e2e-user",
|
||||
"email": "e2e@test.local",
|
||||
"id": "default",
|
||||
"email": "default@test.local",
|
||||
"system_role": "admin",
|
||||
"context_user_id": "e2e-user",
|
||||
"context_user_id": "default",
|
||||
}
|
||||
|
||||
|
||||
@@ -244,8 +246,8 @@ def test_auth_disabled_auth_me_reuses_middleware_user_without_cookie(monkeypatch
|
||||
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {
|
||||
"id": "e2e-user",
|
||||
"email": "e2e@test.local",
|
||||
"id": "default",
|
||||
"email": "default@test.local",
|
||||
"system_role": "admin",
|
||||
"needs_setup": False,
|
||||
}
|
||||
@@ -329,7 +331,7 @@ def test_auth_disabled_startup_warning_when_effective(monkeypatch, caplog):
|
||||
warn_if_auth_disabled_enabled()
|
||||
|
||||
assert "authentication is bypassed" in caplog.text
|
||||
assert "e2e-user" in caplog.text
|
||||
assert "default" in caplog.text
|
||||
|
||||
|
||||
def test_auth_disabled_startup_warning_suppressed_in_explicit_production_env(monkeypatch, caplog):
|
||||
@@ -348,7 +350,8 @@ def test_protected_path_with_junk_cookie_rejected(client):
|
||||
"""Junk cookie → 401. Middleware strictly validates the JWT now
|
||||
(AUTH_TEST_PLAN test 7.5.8); it no longer silently passes bad
|
||||
tokens through to the route handler."""
|
||||
res = client.get("/api/models", cookies={"access_token": "some-token"})
|
||||
client.cookies.set("access_token", "some-token")
|
||||
res = client.get("/api/models")
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
|
||||
@@ -0,0 +1,56 @@
|
||||
"""Tests for user-facing IM channel connection configuration."""
|
||||
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
|
||||
def test_channel_connections_disabled_by_default():
|
||||
config = ChannelConnectionsConfig()
|
||||
|
||||
assert config.enabled is False
|
||||
assert config.slack.enabled is False
|
||||
assert config.telegram.enabled is False
|
||||
assert config.discord.enabled is False
|
||||
assert config.feishu.enabled is False
|
||||
assert config.dingtalk.enabled is False
|
||||
assert config.wechat.enabled is False
|
||||
assert config.wecom.enabled is False
|
||||
|
||||
|
||||
def test_enabled_channel_connections_do_not_require_public_url_or_encryption_key():
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"telegram": {
|
||||
"enabled": True,
|
||||
"bot_username": "deerflow_bot",
|
||||
},
|
||||
"slack": {"enabled": True},
|
||||
"discord": {"enabled": True},
|
||||
"feishu": {"enabled": True},
|
||||
"dingtalk": {"enabled": True},
|
||||
"wechat": {"enabled": True},
|
||||
"wecom": {"enabled": True},
|
||||
}
|
||||
)
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.provider_status("telegram") == {"enabled": True, "configured": True}
|
||||
assert config.provider_status("slack") == {"enabled": True, "configured": True}
|
||||
assert config.provider_status("discord") == {"enabled": True, "configured": True}
|
||||
assert config.provider_status("feishu") == {"enabled": True, "configured": True}
|
||||
assert config.provider_status("dingtalk") == {"enabled": True, "configured": True}
|
||||
assert config.provider_status("wechat") == {"enabled": True, "configured": True}
|
||||
assert config.provider_status("wecom") == {"enabled": True, "configured": True}
|
||||
|
||||
|
||||
def test_provider_status_reports_disabled_and_unknown_providers():
|
||||
config = ChannelConnectionsConfig.model_validate({"enabled": True})
|
||||
|
||||
assert config.provider_status("slack") == {"enabled": False, "configured": False}
|
||||
assert config.provider_status("telegram") == {"enabled": False, "configured": False}
|
||||
assert config.provider_status("discord") == {"enabled": False, "configured": False}
|
||||
assert config.provider_status("feishu") == {"enabled": False, "configured": False}
|
||||
assert config.provider_status("dingtalk") == {"enabled": False, "configured": False}
|
||||
assert config.provider_status("wechat") == {"enabled": False, "configured": False}
|
||||
assert config.provider_status("wecom") == {"enabled": False, "configured": False}
|
||||
assert config.provider_status("unknown") == {"enabled": False, "configured": False}
|
||||
@@ -0,0 +1,331 @@
|
||||
"""Tests for per-user IM channel connection persistence."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from deerflow.persistence.channel_connections import (
|
||||
ChannelConnectionRepository,
|
||||
ChannelConnectionRow,
|
||||
ChannelCredentialCipher,
|
||||
ChannelCredentialRow,
|
||||
ChannelOAuthStateRow,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def repo(tmp_path):
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
|
||||
url = f"sqlite+aiosqlite:///{tmp_path / 'channels.db'}"
|
||||
await init_engine("sqlite", url=url, sqlite_dir=str(tmp_path))
|
||||
try:
|
||||
yield ChannelConnectionRepository(
|
||||
get_session_factory(),
|
||||
cipher=ChannelCredentialCipher.from_key("test-encryption-key"),
|
||||
)
|
||||
finally:
|
||||
await close_engine()
|
||||
|
||||
|
||||
class TestChannelConnectionRepository:
|
||||
@pytest.mark.anyio
|
||||
async def test_connections_are_listed_per_owner(self, repo):
|
||||
alice = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-alice",
|
||||
external_account_name="Alice",
|
||||
workspace_id="T1",
|
||||
workspace_name="Team One",
|
||||
scopes=["chat:write"],
|
||||
)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="bob",
|
||||
provider="slack",
|
||||
external_account_id="U-bob",
|
||||
external_account_name="Bob",
|
||||
workspace_id="T1",
|
||||
workspace_name="Team One",
|
||||
scopes=["chat:write"],
|
||||
)
|
||||
|
||||
results = await repo.list_connections("alice")
|
||||
|
||||
assert [item["id"] for item in results] == [alice["id"]]
|
||||
assert results[0]["owner_user_id"] == "alice"
|
||||
assert results[0]["provider"] == "slack"
|
||||
assert results[0]["scopes"] == ["chat:write"]
|
||||
assert "encrypted_access_token" not in results[0]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_connection_updates_existing_provider_identity(self, repo):
|
||||
first = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
external_account_name="Alice",
|
||||
workspace_id=None,
|
||||
workspace_name=None,
|
||||
status="pending",
|
||||
)
|
||||
second = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
external_account_name="Alice Telegram",
|
||||
workspace_id=None,
|
||||
workspace_name=None,
|
||||
status="connected",
|
||||
)
|
||||
|
||||
assert second["id"] == first["id"]
|
||||
assert second["status"] == "connected"
|
||||
assert second["external_account_name"] == "Alice Telegram"
|
||||
assert len(await repo.list_connections("alice")) == 1
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_credentials_are_encrypted_at_rest_and_decrypted_by_repository(self, repo):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-alice",
|
||||
workspace_id="T1",
|
||||
)
|
||||
expires_at = datetime.now(UTC) + timedelta(hours=1)
|
||||
|
||||
await repo.store_credentials(
|
||||
connection["id"],
|
||||
access_token="xoxb-secret-access-token",
|
||||
refresh_token="secret-refresh-token",
|
||||
token_type="Bearer",
|
||||
expires_at=expires_at,
|
||||
extra={"bot_user_id": "B123"},
|
||||
)
|
||||
|
||||
async with repo.session_factory() as session:
|
||||
row = (await session.execute(select(ChannelCredentialRow))).scalar_one()
|
||||
assert row.encrypted_access_token is not None
|
||||
assert "xoxb-secret-access-token" not in row.encrypted_access_token
|
||||
assert "secret-refresh-token" not in (row.encrypted_refresh_token or "")
|
||||
assert "B123" not in (row.encrypted_extra_json or "")
|
||||
|
||||
credentials = await repo.get_credentials(connection["id"])
|
||||
|
||||
assert credentials is not None
|
||||
assert credentials["access_token"] == "xoxb-secret-access-token"
|
||||
assert credentials["refresh_token"] == "secret-refresh-token"
|
||||
assert credentials["token_type"] == "Bearer"
|
||||
assert credentials["expires_at"] == expires_at
|
||||
assert credentials["extra"] == {"bot_user_id": "B123"}
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_get_credentials_returns_none_when_decryption_fails(self, repo, caplog):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-alice",
|
||||
workspace_id="T1",
|
||||
)
|
||||
await repo.store_credentials(connection["id"], access_token="xoxb-secret-access-token")
|
||||
wrong_key_repo = ChannelConnectionRepository(
|
||||
repo.session_factory,
|
||||
cipher=ChannelCredentialCipher.from_key("wrong-encryption-key"),
|
||||
)
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="deerflow.persistence.channel_connections.sql"):
|
||||
credentials = await wrong_key_repo.get_credentials(connection["id"])
|
||||
|
||||
assert credentials is None
|
||||
assert any("Unable to decrypt channel connection credentials" in record.message for record in caplog.records)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_conversations_are_scoped_by_connection(self, repo):
|
||||
alice = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-alice",
|
||||
workspace_id="T1",
|
||||
)
|
||||
bob = await repo.upsert_connection(
|
||||
owner_user_id="bob",
|
||||
provider="slack",
|
||||
external_account_id="U-bob",
|
||||
workspace_id="T1",
|
||||
)
|
||||
|
||||
await repo.set_thread_id(
|
||||
connection_id=alice["id"],
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_conversation_id="C-shared",
|
||||
external_topic_id="1710000000.000100",
|
||||
thread_id="thread-alice",
|
||||
)
|
||||
await repo.set_thread_id(
|
||||
connection_id=bob["id"],
|
||||
owner_user_id="bob",
|
||||
provider="slack",
|
||||
external_conversation_id="C-shared",
|
||||
external_topic_id="1710000000.000100",
|
||||
thread_id="thread-bob",
|
||||
)
|
||||
|
||||
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
|
||||
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_disconnect_connection_revokes_owner_connection_and_removes_credentials(self, repo):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
)
|
||||
await repo.store_credentials(connection["id"], access_token="secret-token")
|
||||
|
||||
disconnected = await repo.disconnect_connection(
|
||||
connection_id=connection["id"],
|
||||
owner_user_id="alice",
|
||||
)
|
||||
|
||||
assert disconnected is True
|
||||
async with repo.session_factory() as session:
|
||||
connection_row = await session.get(ChannelConnectionRow, connection["id"])
|
||||
credential_row = await session.get(ChannelCredentialRow, connection["id"])
|
||||
assert connection_row is not None
|
||||
assert connection_row.status == "revoked"
|
||||
assert credential_row is None
|
||||
assert (
|
||||
await repo.find_connection_by_external_identity(
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_disconnect_connection_is_owner_scoped(self, repo):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
)
|
||||
|
||||
disconnected = await repo.disconnect_connection(
|
||||
connection_id=connection["id"],
|
||||
owner_user_id="bob",
|
||||
)
|
||||
|
||||
assert disconnected is False
|
||||
assert (await repo.list_connections("alice"))[0]["status"] == "connected"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_consume_oauth_state_deletes_expired_states(self, repo):
|
||||
now = datetime.now(UTC)
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
state="expired-state",
|
||||
expires_at=now - timedelta(minutes=1),
|
||||
)
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
state="active-state",
|
||||
expires_at=now + timedelta(minutes=5),
|
||||
)
|
||||
|
||||
consumed = await repo.consume_oauth_state(provider="slack", state="expired-state", now=now)
|
||||
|
||||
assert consumed is None
|
||||
async with repo.session_factory() as session:
|
||||
states = (await session.execute(select(ChannelOAuthStateRow))).scalars().all()
|
||||
assert [state.state_hash for state in states] == [repo.hash_state("active-state")]
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_consume_oauth_state_is_one_time_even_under_concurrent_consumers(self, repo):
|
||||
import anyio
|
||||
|
||||
now = datetime.now(UTC)
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
state="bind-once",
|
||||
expires_at=now + timedelta(minutes=5),
|
||||
)
|
||||
|
||||
results: list = []
|
||||
|
||||
async def consume():
|
||||
results.append(await repo.consume_oauth_state(provider="slack", state="bind-once", now=now))
|
||||
|
||||
async with anyio.create_task_group() as tg:
|
||||
tg.start_soon(consume)
|
||||
tg.start_soon(consume)
|
||||
|
||||
consumed = [result for result in results if result is not None]
|
||||
assert len(consumed) == 1
|
||||
assert consumed[0]["owner_user_id"] == "alice"
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_upsert_connection_retries_as_update_when_concurrent_insert_wins(self, repo):
|
||||
"""A losing concurrent INSERT retries as an UPDATE instead of raising IntegrityError."""
|
||||
first = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-race",
|
||||
workspace_id="T-race",
|
||||
status="pending",
|
||||
)
|
||||
|
||||
real_factory = repo.session_factory
|
||||
|
||||
class _EmptyResult:
|
||||
@staticmethod
|
||||
def scalar_one_or_none():
|
||||
return None
|
||||
|
||||
class MissFirstSelectSession:
|
||||
"""Make the initial identity SELECT miss, as if a concurrent writer inserted after it."""
|
||||
|
||||
def __init__(self, session):
|
||||
self._session = session
|
||||
self._missed = False
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._session, name)
|
||||
|
||||
async def execute(self, *args, **kwargs):
|
||||
result = await self._session.execute(*args, **kwargs)
|
||||
if not self._missed:
|
||||
self._missed = True
|
||||
return _EmptyResult()
|
||||
return result
|
||||
|
||||
async def __aenter__(self):
|
||||
await self._session.__aenter__()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, *args):
|
||||
return await self._session.__aexit__(*args)
|
||||
|
||||
repo.session_factory = lambda: MissFirstSelectSession(real_factory())
|
||||
try:
|
||||
second = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-race",
|
||||
workspace_id="T-race",
|
||||
status="connected",
|
||||
)
|
||||
finally:
|
||||
repo.session_factory = real_factory
|
||||
|
||||
assert second["id"] == first["id"]
|
||||
assert second["status"] == "connected"
|
||||
connections = await repo.list_connections("alice")
|
||||
assert len(connections) == 1
|
||||
@@ -0,0 +1,963 @@
|
||||
"""Router tests for browser-connectable IM channels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from tempfile import TemporaryDirectory
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
from _router_auth_helpers import make_authed_test_app
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||
from app.gateway.auth.models import User
|
||||
from app.gateway.routers import channel_connections
|
||||
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _stub_app_config(monkeypatch):
|
||||
"""Keep router tests independent from a developer-local config.yaml."""
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "0")
|
||||
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
|
||||
yield
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def _user() -> User:
|
||||
return User(
|
||||
id=UUID("11111111-2222-3333-4444-555555555555"),
|
||||
email="alice@example.com",
|
||||
password_hash="x",
|
||||
system_role="admin",
|
||||
)
|
||||
|
||||
|
||||
def _non_admin_user() -> User:
|
||||
return User(
|
||||
id=UUID("99999999-8888-7777-6666-555555555555"),
|
||||
email="bob@example.com",
|
||||
password_hash="x",
|
||||
system_role="user",
|
||||
)
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'router.db'}", sqlite_dir=str(tmp_path))
|
||||
return ChannelConnectionRepository(get_session_factory())
|
||||
|
||||
|
||||
def _make_app(
|
||||
config: ChannelConnectionsConfig,
|
||||
repo,
|
||||
channels_config: dict | None = None,
|
||||
*,
|
||||
runtime_config_store: ChannelRuntimeConfigStore | None = None,
|
||||
set_channels_config_state: bool = True,
|
||||
):
|
||||
app = make_authed_test_app(user_factory=_user)
|
||||
app.state.channel_connections_config = config
|
||||
app.state.channel_connection_repo = repo
|
||||
if set_channels_config_state:
|
||||
app.state.channels_config = channels_config or {}
|
||||
if runtime_config_store is None:
|
||||
runtime_config_dir = TemporaryDirectory()
|
||||
app.state.channel_runtime_config_tmpdir = runtime_config_dir
|
||||
runtime_config_store = ChannelRuntimeConfigStore(f"{runtime_config_dir.name}/runtime-config.json")
|
||||
app.state.channel_runtime_config_store = runtime_config_store
|
||||
app.include_router(channel_connections.router)
|
||||
return app
|
||||
|
||||
|
||||
def _enabled_connections_config() -> ChannelConnectionsConfig:
|
||||
return ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
|
||||
"slack": {"enabled": True},
|
||||
"discord": {"enabled": True},
|
||||
"feishu": {"enabled": True},
|
||||
"dingtalk": {"enabled": True},
|
||||
"wechat": {"enabled": True},
|
||||
"wecom": {"enabled": True},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _channels_config() -> dict:
|
||||
return {
|
||||
"telegram": {"enabled": True, "bot_token": "telegram-token"},
|
||||
"slack": {"enabled": True, "bot_token": "xoxb-operator", "app_token": "xapp-operator"},
|
||||
"discord": {"enabled": True, "bot_token": "discord-bot"},
|
||||
"feishu": {"enabled": True, "app_id": "feishu-app", "app_secret": "feishu-secret"},
|
||||
"dingtalk": {"enabled": True, "client_id": "dingtalk-client", "client_secret": "dingtalk-secret"},
|
||||
"wechat": {"enabled": True, "bot_token": "wechat-token"},
|
||||
"wecom": {"enabled": True, "bot_id": "wecom-bot", "bot_secret": "wecom-secret"},
|
||||
}
|
||||
|
||||
|
||||
def test_get_providers_only_returns_enabled_channels_and_setup_fields(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
"discord": {"enabled": False},
|
||||
}
|
||||
)
|
||||
app = _make_app(config, repo, {})
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["enabled"] is True
|
||||
assert [provider["provider"] for provider in body["providers"]] == ["slack"]
|
||||
assert body["providers"][0]["configured"] is False
|
||||
assert body["providers"][0]["connectable"] is False
|
||||
assert body["providers"][0]["credential_fields"] == [
|
||||
{
|
||||
"name": "bot_token",
|
||||
"label": "Bot token",
|
||||
"type": "password",
|
||||
"required": True,
|
||||
},
|
||||
{
|
||||
"name": "app_token",
|
||||
"label": "App token",
|
||||
"type": "password",
|
||||
"required": True,
|
||||
},
|
||||
]
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_providers_uses_existing_channels_config(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["enabled"] is True
|
||||
by_provider = {item["provider"]: item for item in body["providers"]}
|
||||
assert set(by_provider) == {"telegram", "slack", "discord", "feishu", "dingtalk", "wechat", "wecom"}
|
||||
assert by_provider["telegram"]["configured"] is True
|
||||
assert by_provider["telegram"]["auth_mode"] == "deep_link"
|
||||
assert by_provider["telegram"]["credential_values"] == {
|
||||
"bot_token": "********",
|
||||
"bot_username": "deerflow_bot",
|
||||
}
|
||||
assert by_provider["slack"]["configured"] is True
|
||||
assert by_provider["slack"]["auth_mode"] == "binding_code"
|
||||
assert by_provider["slack"]["connection_status"] == "not_connected"
|
||||
assert by_provider["slack"]["credential_values"] == {
|
||||
"bot_token": "********",
|
||||
"app_token": "********",
|
||||
}
|
||||
assert by_provider["discord"]["configured"] is True
|
||||
assert by_provider["discord"]["auth_mode"] == "binding_code"
|
||||
assert by_provider["discord"]["credential_values"] == {"bot_token": "********"}
|
||||
assert by_provider["feishu"]["configured"] is True
|
||||
assert by_provider["feishu"]["auth_mode"] == "binding_code"
|
||||
assert by_provider["feishu"]["connection_status"] == "not_connected"
|
||||
assert by_provider["feishu"]["credential_values"] == {
|
||||
"app_id": "feishu-app",
|
||||
"app_secret": "********",
|
||||
}
|
||||
assert by_provider["dingtalk"]["configured"] is True
|
||||
assert by_provider["dingtalk"]["auth_mode"] == "binding_code"
|
||||
assert by_provider["dingtalk"]["credential_values"] == {
|
||||
"client_id": "dingtalk-client",
|
||||
"client_secret": "********",
|
||||
}
|
||||
assert by_provider["wechat"]["configured"] is True
|
||||
assert by_provider["wechat"]["auth_mode"] == "binding_code"
|
||||
assert by_provider["wechat"]["credential_values"] == {"bot_token": "********"}
|
||||
assert by_provider["wecom"]["configured"] is True
|
||||
assert by_provider["wecom"]["auth_mode"] == "binding_code"
|
||||
assert by_provider["wecom"]["credential_values"] == {
|
||||
"bot_id": "wecom-bot",
|
||||
"bot_secret": "********",
|
||||
}
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_providers_degrades_when_persistence_is_unavailable(monkeypatch):
|
||||
monkeypatch.setattr(channel_connections, "get_session_factory", lambda: None)
|
||||
app = _make_app(_enabled_connections_config(), None, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||
assert by_provider["slack"]["configured"] is True
|
||||
assert by_provider["slack"]["connectable"] is True
|
||||
assert by_provider["slack"]["connection_status"] == "not_connected"
|
||||
|
||||
|
||||
def test_get_providers_reports_connected_without_binding_in_auth_disabled_mode(tmp_path, monkeypatch):
|
||||
import anyio
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
monkeypatch.delenv("DEER_FLOW_ENV", raising=False)
|
||||
monkeypatch.delenv("ENVIRONMENT", raising=False)
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||
# Auth-disabled local mode routes channel messages to the default user, so
|
||||
# a configured running channel is effectively connected without a binding.
|
||||
assert by_provider["slack"]["connection_status"] == "connected"
|
||||
assert by_provider["feishu"]["connection_status"] == "connected"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_providers_reports_unconfigured_when_runtime_channel_is_missing(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, {"telegram": {"enabled": True, "bot_token": "telegram-token"}})
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||
assert by_provider["telegram"]["configured"] is True
|
||||
assert by_provider["slack"]["configured"] is False
|
||||
assert by_provider["slack"]["connectable"] is False
|
||||
assert "Slack credentials" in by_provider["slack"]["unavailable_reason"]
|
||||
assert by_provider["discord"]["configured"] is False
|
||||
assert "Discord credentials" in by_provider["discord"]["unavailable_reason"]
|
||||
assert by_provider["feishu"]["configured"] is False
|
||||
assert "Feishu credentials" in by_provider["feishu"]["unavailable_reason"]
|
||||
assert by_provider["dingtalk"]["configured"] is False
|
||||
assert "DingTalk credentials" in by_provider["dingtalk"]["unavailable_reason"]
|
||||
assert by_provider["wechat"]["configured"] is False
|
||||
assert "WeChat credentials" in by_provider["wechat"]["unavailable_reason"]
|
||||
assert by_provider["wecom"]["configured"] is False
|
||||
assert "WeCom credentials" in by_provider["wecom"]["unavailable_reason"]
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_providers_reports_configured_channel_not_running(tmp_path, monkeypatch):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
service = SimpleNamespace(
|
||||
get_status=lambda: {
|
||||
"service_running": True,
|
||||
"channels": {
|
||||
"feishu": {
|
||||
"enabled": True,
|
||||
"running": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||
assert by_provider["feishu"]["configured"] is True
|
||||
assert by_provider["feishu"]["connectable"] is False
|
||||
assert by_provider["feishu"]["connection_status"] == "not_connected"
|
||||
assert "configured but is not running" in by_provider["feishu"]["unavailable_reason"]
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_providers_restarts_configured_channel_when_service_can_reconcile(tmp_path, monkeypatch):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"feishu": {"enabled": True},
|
||||
}
|
||||
)
|
||||
channels_config = {
|
||||
"feishu": {
|
||||
"enabled": True,
|
||||
"app_id": "feishu-app",
|
||||
"app_secret": "feishu-secret",
|
||||
}
|
||||
}
|
||||
app = _make_app(config, repo, channels_config)
|
||||
status = {
|
||||
"service_running": True,
|
||||
"channels": {
|
||||
"feishu": {
|
||||
"enabled": True,
|
||||
"running": False,
|
||||
}
|
||||
},
|
||||
}
|
||||
reconciled: list[tuple[str, dict]] = []
|
||||
|
||||
async def ensure_channel_ready(provider, runtime_config):
|
||||
reconciled.append((provider, dict(runtime_config)))
|
||||
status["channels"][provider]["running"] = True
|
||||
return True
|
||||
|
||||
service = SimpleNamespace(
|
||||
get_status=lambda: status,
|
||||
ensure_channel_ready=ensure_channel_ready,
|
||||
)
|
||||
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||
assert by_provider["feishu"]["configured"] is True
|
||||
assert by_provider["feishu"]["connectable"] is True
|
||||
assert by_provider["feishu"]["connection_status"] == "not_connected"
|
||||
assert by_provider["feishu"]["unavailable_reason"] is None
|
||||
assert reconciled == [("feishu", channels_config["feishu"])]
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_providers_uses_newest_connection_status_per_provider(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
|
||||
async def seed_connections():
|
||||
await repo.upsert_connection(
|
||||
owner_user_id=str(_user().id),
|
||||
provider="slack",
|
||||
external_account_id="U-old",
|
||||
workspace_id="T-old",
|
||||
status="revoked",
|
||||
)
|
||||
await anyio.sleep(0.01)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id=str(_user().id),
|
||||
provider="slack",
|
||||
external_account_id="U-new",
|
||||
workspace_id="T-new",
|
||||
status="connected",
|
||||
)
|
||||
|
||||
anyio.run(seed_connections)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||
assert by_provider["slack"]["connection_status"] == "connected"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_get_connections_returns_current_user_connections_only(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
|
||||
async def seed_connections():
|
||||
await repo.upsert_connection(
|
||||
owner_user_id=str(_user().id),
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
external_account_name="Alice",
|
||||
status="connected",
|
||||
)
|
||||
await repo.upsert_connection(
|
||||
owner_user_id="other-user",
|
||||
provider="telegram",
|
||||
external_account_id="99",
|
||||
external_account_name="Bob",
|
||||
status="connected",
|
||||
)
|
||||
|
||||
anyio.run(seed_connections)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.get("/api/channels/connections")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert len(body["connections"]) == 1
|
||||
assert body["connections"][0]["provider"] == "telegram"
|
||||
assert body["connections"][0]["external_account_id"] == "42"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_telegram_returns_deep_link_and_persists_state(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/channels/telegram/connect")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["provider"] == "telegram"
|
||||
assert body["mode"] == "deep_link"
|
||||
assert body["url"].startswith("https://t.me/deerflow_bot?start=")
|
||||
assert body["code"]
|
||||
assert "/start" in body["instruction"]
|
||||
|
||||
async def count_states():
|
||||
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="telegram")
|
||||
|
||||
assert anyio.run(count_states) == 1
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_slack_returns_binding_command_and_persists_state(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/channels/slack/connect")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["provider"] == "slack"
|
||||
assert body["mode"] == "binding_code"
|
||||
assert body["url"] is None
|
||||
assert len(body["code"]) >= 22
|
||||
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Slack bot."
|
||||
|
||||
async def count_states():
|
||||
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="slack")
|
||||
|
||||
assert anyio.run(count_states) == 1
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_discord_returns_binding_command_and_persists_state(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/channels/discord/connect")
|
||||
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["provider"] == "discord"
|
||||
assert body["mode"] == "binding_code"
|
||||
assert body["url"] is None
|
||||
assert body["code"]
|
||||
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow Discord bot."
|
||||
|
||||
async def count_states():
|
||||
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider="discord")
|
||||
|
||||
assert anyio.run(count_states) == 1
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_existing_binding_code_channels_return_command_and_persist_state(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
providers = ["feishu", "dingtalk", "wechat", "wecom"]
|
||||
with TestClient(app) as client:
|
||||
responses = {provider: client.post(f"/api/channels/{provider}/connect") for provider in providers}
|
||||
|
||||
for provider, response in responses.items():
|
||||
expected_display_name = {
|
||||
"feishu": "Feishu",
|
||||
"dingtalk": "DingTalk",
|
||||
"wechat": "WeChat",
|
||||
"wecom": "WeCom",
|
||||
}[provider]
|
||||
assert response.status_code == 200
|
||||
body = response.json()
|
||||
assert body["provider"] == provider
|
||||
assert body["mode"] == "binding_code"
|
||||
assert body["url"] is None
|
||||
assert len(body["code"]) >= 22
|
||||
assert body["instruction"] == f"Send /connect {body['code']} to the DeerFlow {expected_display_name} bot."
|
||||
|
||||
async def count_states(provider=provider):
|
||||
return await repo.count_oauth_states(owner_user_id=str(_user().id), provider=provider)
|
||||
|
||||
assert anyio.run(count_states) == 1
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_connect_unconfigured_runtime_channel_returns_400(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
app = _make_app(_enabled_connections_config(), repo, {})
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.post("/api/channels/slack/connect")
|
||||
|
||||
assert response.status_code == 400
|
||||
assert "Slack credentials" in response.json()["detail"]
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_configure_provider_runtime_credentials_enables_connect_without_file_edits(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
}
|
||||
)
|
||||
app = _make_app(config, repo, {})
|
||||
|
||||
with TestClient(app) as client:
|
||||
configure_response = client.post(
|
||||
"/api/channels/slack/runtime-config",
|
||||
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||
)
|
||||
connect_response = client.post("/api/channels/slack/connect")
|
||||
|
||||
assert configure_response.status_code == 200
|
||||
configured = configure_response.json()
|
||||
assert configured["provider"] == "slack"
|
||||
assert configured["configured"] is True
|
||||
assert configured["connectable"] is True
|
||||
assert configured["connection_status"] == "not_connected"
|
||||
assert app.state.channels_config["slack"] == {
|
||||
"enabled": True,
|
||||
"bot_token": "xoxb-ui",
|
||||
"app_token": "xapp-ui",
|
||||
}
|
||||
assert connect_response.status_code == 200
|
||||
assert connect_response.json()["provider"] == "slack"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_runtime_config_endpoints_require_admin(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
}
|
||||
)
|
||||
app = make_authed_test_app(user_factory=_non_admin_user)
|
||||
app.state.channel_connections_config = config
|
||||
app.state.channel_connection_repo = repo
|
||||
app.state.channels_config = {}
|
||||
runtime_config_dir = TemporaryDirectory()
|
||||
app.state.channel_runtime_config_tmpdir = runtime_config_dir
|
||||
app.state.channel_runtime_config_store = ChannelRuntimeConfigStore(f"{runtime_config_dir.name}/runtime-config.json")
|
||||
app.include_router(channel_connections.router)
|
||||
|
||||
with TestClient(app) as client:
|
||||
configure_response = client.post(
|
||||
"/api/channels/slack/runtime-config",
|
||||
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||
)
|
||||
disconnect_response = client.delete("/api/channels/slack/runtime-config")
|
||||
providers_response = client.get("/api/channels/providers")
|
||||
|
||||
assert configure_response.status_code == 403
|
||||
assert "Admin privileges" in configure_response.json()["detail"]
|
||||
assert disconnect_response.status_code == 403
|
||||
# Read-only provider listing stays available to regular users.
|
||||
assert providers_response.status_code == 200
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_configure_telegram_runtime_uses_new_bot_username_for_deep_link_without_mutating_config(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"telegram": {"enabled": True, "bot_username": "old_bot"},
|
||||
}
|
||||
)
|
||||
app = _make_app(config, repo, {})
|
||||
|
||||
with TestClient(app) as client:
|
||||
configure_response = client.post(
|
||||
"/api/channels/telegram/runtime-config",
|
||||
json={"values": {"bot_token": "tg-token", "bot_username": "new_bot"}},
|
||||
)
|
||||
connect_response = client.post("/api/channels/telegram/connect")
|
||||
|
||||
assert configure_response.status_code == 200
|
||||
assert configure_response.json()["credential_values"]["bot_username"] == "new_bot"
|
||||
assert connect_response.status_code == 200
|
||||
assert connect_response.json()["url"].startswith("https://t.me/new_bot?start=")
|
||||
# The original config object cached by get_app_config() must stay untouched.
|
||||
assert config.telegram.bot_username == "old_bot"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_configure_provider_runtime_credentials_survive_local_restart(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
}
|
||||
)
|
||||
runtime_config_path = tmp_path / "channels" / "runtime-config.json"
|
||||
first_app = _make_app(
|
||||
config,
|
||||
repo,
|
||||
{},
|
||||
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
|
||||
)
|
||||
|
||||
with TestClient(first_app) as client:
|
||||
configure_response = client.post(
|
||||
"/api/channels/slack/runtime-config",
|
||||
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||
)
|
||||
|
||||
assert configure_response.status_code == 200
|
||||
|
||||
restarted_app = _make_app(
|
||||
config,
|
||||
repo,
|
||||
runtime_config_store=ChannelRuntimeConfigStore(runtime_config_path),
|
||||
set_channels_config_state=False,
|
||||
)
|
||||
|
||||
with TestClient(restarted_app) as client:
|
||||
response = client.get("/api/channels/providers")
|
||||
|
||||
assert response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in response.json()["providers"]}
|
||||
assert by_provider["slack"]["configured"] is True
|
||||
assert by_provider["slack"]["connectable"] is True
|
||||
assert by_provider["slack"]["connection_status"] == "not_connected"
|
||||
assert restarted_app.state.channels_config["slack"] == {
|
||||
"enabled": True,
|
||||
"bot_token": "xoxb-ui",
|
||||
"app_token": "xapp-ui",
|
||||
}
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_configure_provider_runtime_credentials_preserves_masked_secrets(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"feishu": {"enabled": True},
|
||||
}
|
||||
)
|
||||
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||
app = _make_app(
|
||||
config,
|
||||
repo,
|
||||
{
|
||||
"feishu": {
|
||||
"enabled": True,
|
||||
"app_id": "old-app-id",
|
||||
"app_secret": "old-secret",
|
||||
}
|
||||
},
|
||||
runtime_config_store=runtime_config_store,
|
||||
)
|
||||
|
||||
with TestClient(app) as client:
|
||||
configure_response = client.post(
|
||||
"/api/channels/feishu/runtime-config",
|
||||
json={
|
||||
"values": {
|
||||
"app_id": "new-app-id",
|
||||
"app_secret": "********",
|
||||
}
|
||||
},
|
||||
)
|
||||
providers_response = client.get("/api/channels/providers")
|
||||
|
||||
assert configure_response.status_code == 200
|
||||
assert app.state.channels_config["feishu"] == {
|
||||
"enabled": True,
|
||||
"app_id": "new-app-id",
|
||||
"app_secret": "old-secret",
|
||||
}
|
||||
assert runtime_config_store.get_provider_config("feishu") == {
|
||||
"enabled": True,
|
||||
"app_id": "new-app-id",
|
||||
"app_secret": "old-secret",
|
||||
}
|
||||
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||
assert by_provider["feishu"]["credential_values"] == {
|
||||
"app_id": "new-app-id",
|
||||
"app_secret": "********",
|
||||
}
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_provider_runtime_config_clears_connected_state(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
}
|
||||
)
|
||||
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
|
||||
|
||||
with TestClient(app) as client:
|
||||
configure_response = client.post(
|
||||
"/api/channels/slack/runtime-config",
|
||||
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||
)
|
||||
disconnect_response = client.delete("/api/channels/slack/runtime-config")
|
||||
providers_response = client.get("/api/channels/providers")
|
||||
|
||||
assert configure_response.status_code == 200
|
||||
assert disconnect_response.status_code == 200
|
||||
disconnected = disconnect_response.json()
|
||||
assert disconnected["provider"] == "slack"
|
||||
assert disconnected["configured"] is False
|
||||
assert disconnected["connectable"] is False
|
||||
assert disconnected["connection_status"] == "not_connected"
|
||||
assert runtime_config_store.get_provider_config("slack") == {
|
||||
"enabled": False,
|
||||
"_runtime_disabled": True,
|
||||
}
|
||||
|
||||
assert providers_response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||
assert by_provider["slack"]["connection_status"] == "not_connected"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_provider_runtime_config_suppresses_file_config_and_stops_channel(tmp_path, monkeypatch):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"feishu": {"enabled": True},
|
||||
}
|
||||
)
|
||||
set_app_config(
|
||||
AppConfig.model_validate(
|
||||
{
|
||||
"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"},
|
||||
"channels": {
|
||||
"feishu": {
|
||||
"enabled": True,
|
||||
"app_id": "file-app-id",
|
||||
"app_secret": "file-secret",
|
||||
}
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||
runtime_config_store.set_provider_config(
|
||||
"feishu",
|
||||
{
|
||||
"enabled": True,
|
||||
"app_id": "runtime-app-id",
|
||||
"app_secret": "runtime-secret",
|
||||
},
|
||||
)
|
||||
service = SimpleNamespace(
|
||||
configure_channel=AsyncMock(return_value=True),
|
||||
remove_channel=AsyncMock(return_value=True),
|
||||
)
|
||||
monkeypatch.setattr("app.channels.service.get_channel_service", lambda: service)
|
||||
app = _make_app(
|
||||
config,
|
||||
repo,
|
||||
{
|
||||
"feishu": {
|
||||
"enabled": True,
|
||||
"app_id": "runtime-app-id",
|
||||
"app_secret": "runtime-secret",
|
||||
}
|
||||
},
|
||||
runtime_config_store=runtime_config_store,
|
||||
)
|
||||
|
||||
with TestClient(app) as client:
|
||||
disconnect_response = client.delete("/api/channels/feishu/runtime-config")
|
||||
providers_response = client.get("/api/channels/providers")
|
||||
|
||||
assert disconnect_response.status_code == 200
|
||||
disconnected = disconnect_response.json()
|
||||
assert disconnected["provider"] == "feishu"
|
||||
assert disconnected["configured"] is False
|
||||
assert disconnected["connectable"] is False
|
||||
assert disconnected["connection_status"] == "not_connected"
|
||||
assert "feishu" not in app.state.channels_config
|
||||
service.remove_channel.assert_awaited_once_with("feishu")
|
||||
service.configure_channel.assert_not_awaited()
|
||||
|
||||
assert providers_response.status_code == 200
|
||||
by_provider = {item["provider"]: item for item in providers_response.json()["providers"]}
|
||||
assert by_provider["feishu"]["configured"] is False
|
||||
assert by_provider["feishu"]["connection_status"] == "not_connected"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_provider_runtime_config_revokes_current_user_provider_connections(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
|
||||
async def seed_connection():
|
||||
await repo.upsert_connection(
|
||||
owner_user_id=str(_user().id),
|
||||
provider="slack",
|
||||
external_account_id="U123",
|
||||
status="connected",
|
||||
)
|
||||
|
||||
anyio.run(seed_connection)
|
||||
config = ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
}
|
||||
)
|
||||
runtime_config_store = ChannelRuntimeConfigStore(tmp_path / "channels" / "runtime-config.json")
|
||||
app = _make_app(config, repo, {}, runtime_config_store=runtime_config_store)
|
||||
|
||||
with TestClient(app) as client:
|
||||
configure_response = client.post(
|
||||
"/api/channels/slack/runtime-config",
|
||||
json={"values": {"bot_token": "xoxb-ui", "app_token": "xapp-ui"}},
|
||||
)
|
||||
disconnect_response = client.delete("/api/channels/slack/runtime-config")
|
||||
|
||||
assert configure_response.status_code == 200
|
||||
assert disconnect_response.status_code == 200
|
||||
|
||||
async def get_connection_status():
|
||||
return (await repo.list_connections(str(_user().id)))[0]["status"]
|
||||
|
||||
assert anyio.run(get_connection_status) == "revoked"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_connection_revokes_current_user_connection(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
|
||||
async def seed_connection():
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id=str(_user().id),
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
status="connected",
|
||||
)
|
||||
return connection["id"]
|
||||
|
||||
connection_id = anyio.run(seed_connection)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.delete(f"/api/channels/connections/{connection_id}")
|
||||
|
||||
assert response.status_code == 204
|
||||
|
||||
async def get_connection_status():
|
||||
return (await repo.list_connections(str(_user().id)))[0]["status"]
|
||||
|
||||
assert anyio.run(get_connection_status) == "revoked"
|
||||
|
||||
anyio.run(repo.close)
|
||||
|
||||
|
||||
def test_disconnect_connection_is_current_user_scoped(tmp_path):
|
||||
import anyio
|
||||
|
||||
repo = anyio.run(_make_repo, tmp_path)
|
||||
|
||||
async def seed_connection():
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="other-user",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
status="connected",
|
||||
)
|
||||
return connection["id"]
|
||||
|
||||
connection_id = anyio.run(seed_connection)
|
||||
app = _make_app(_enabled_connections_config(), repo, _channels_config())
|
||||
|
||||
with TestClient(app) as client:
|
||||
response = client.delete(f"/api/channels/connections/{connection_id}")
|
||||
|
||||
assert response.status_code == 404
|
||||
|
||||
async def get_connection_status():
|
||||
return (await repo.list_connections("other-user"))[0]["status"]
|
||||
|
||||
assert anyio.run(get_connection_status) == "connected"
|
||||
|
||||
anyio.run(repo.close)
|
||||
+565
-10
@@ -487,6 +487,7 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None):
|
||||
|
||||
# threads.create() returns a Thread-like dict
|
||||
mock_client.threads.create = AsyncMock(return_value={"thread_id": thread_id})
|
||||
mock_client.threads.update = AsyncMock(return_value={"thread_id": thread_id})
|
||||
|
||||
# threads.get() returns thread info (succeeds by default)
|
||||
mock_client.threads.get = AsyncMock(return_value={"thread_id": thread_id})
|
||||
@@ -504,6 +505,17 @@ def _make_mock_langgraph_client(thread_id="test-thread-123", run_result=None):
|
||||
return mock_client
|
||||
|
||||
|
||||
async def _make_channel_connection_repo(tmp_path: Path):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'channel-connections.db'}", sqlite_dir=str(tmp_path))
|
||||
return ChannelConnectionRepository(
|
||||
get_session_factory(),
|
||||
cipher=ChannelCredentialCipher.from_key("test-channel-key"),
|
||||
)
|
||||
|
||||
|
||||
def _make_stream_part(event: str, data):
|
||||
return SimpleNamespace(event=event, data=data)
|
||||
|
||||
@@ -656,16 +668,34 @@ class TestChannelManager:
|
||||
|
||||
await manager.start()
|
||||
|
||||
inbound = InboundMessage(channel_name="test", chat_id="chat1", user_id="user1", text="hi")
|
||||
inbound = InboundMessage(
|
||||
channel_name="test",
|
||||
chat_id="chat1",
|
||||
user_id="user1",
|
||||
text="hi",
|
||||
topic_id="topic1",
|
||||
thread_ts="msg1",
|
||||
connection_id="conn1",
|
||||
)
|
||||
await bus.publish_inbound(inbound)
|
||||
await _wait_for(lambda: len(outbound_received) >= 1)
|
||||
await manager.stop()
|
||||
|
||||
# Thread should be created through Gateway
|
||||
mock_client.threads.create.assert_called_once()
|
||||
assert mock_client.threads.create.call_args.kwargs["metadata"] == {
|
||||
"channel_source": {
|
||||
"type": "im_channel",
|
||||
"provider": "test",
|
||||
"chat_id": "chat1",
|
||||
"topic_id": "topic1",
|
||||
"thread_ts": "msg1",
|
||||
"connection_id": "conn1",
|
||||
}
|
||||
}
|
||||
|
||||
# Thread ID should be stored
|
||||
thread_id = store.get_thread_id("test", "chat1")
|
||||
thread_id = store.get_thread_id("test", "chat1", topic_id="topic1")
|
||||
assert thread_id == "test-thread-123"
|
||||
|
||||
# runs.wait should be called with the thread_id
|
||||
@@ -883,10 +913,12 @@ class TestChannelManager:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_clarification_follow_up_preserves_history(self):
|
||||
def test_clarification_follow_up_preserves_history(self, monkeypatch):
|
||||
"""Conversation should continue after ask_clarification instead of resetting history."""
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
@@ -1954,10 +1986,12 @@ class TestChannelManager:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_same_topic_reuses_thread(self):
|
||||
def test_same_topic_reuses_thread(self, monkeypatch):
|
||||
"""Messages with the same topic_id should reuse the same DeerFlow thread."""
|
||||
from app.channels.manager import ChannelManager
|
||||
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
|
||||
async def go():
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
@@ -1990,6 +2024,17 @@ class TestChannelManager:
|
||||
|
||||
# threads.create should be called only ONCE (second message reuses the thread)
|
||||
mock_client.threads.create.assert_called_once()
|
||||
mock_client.threads.update.assert_called_once_with(
|
||||
"topic-thread-1",
|
||||
metadata={
|
||||
"channel_source": {
|
||||
"type": "im_channel",
|
||||
"provider": "test",
|
||||
"chat_id": "chat1",
|
||||
"topic_id": "topic-root-123",
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Both runs.wait calls should use the same thread_id
|
||||
assert mock_client.runs.wait.call_count == 2
|
||||
@@ -2325,8 +2370,9 @@ class TestResolveRunParamsUserId:
|
||||
store = ChannelStore(path=Path(tempfile.mkdtemp()) / "store.json")
|
||||
return ChannelManager(bus=bus, store=store)
|
||||
|
||||
def test_safe_user_id_is_passed_through(self):
|
||||
def test_safe_user_id_is_passed_through(self, monkeypatch):
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
msg = InboundMessage(channel_name="telegram", chat_id="c", user_id="123456", text="hi")
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
@@ -2334,10 +2380,78 @@ class TestResolveRunParamsUserId:
|
||||
assert run_context["user_id"] == "123456"
|
||||
assert run_context["channel_user_id"] == "123456"
|
||||
|
||||
def test_unsafe_user_id_is_normalized_but_raw_preserved(self):
|
||||
def test_connection_owner_user_id_takes_precedence_over_platform_user_id(self, monkeypatch):
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
msg = InboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C123",
|
||||
user_id="U-platform",
|
||||
owner_user_id="deerflow-user-1",
|
||||
connection_id="connection-1",
|
||||
text="hi",
|
||||
)
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
|
||||
assert run_context["user_id"] == "deerflow-user-1"
|
||||
assert run_context["channel_user_id"] == "U-platform"
|
||||
|
||||
def test_auth_disabled_user_id_is_used_for_unbound_channel_messages(self, monkeypatch):
|
||||
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
|
||||
manager = self._manager()
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
|
||||
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
|
||||
assert run_context["channel_user_id"] == "U-platform"
|
||||
|
||||
from app.channels.manager import _owner_headers
|
||||
|
||||
headers = _owner_headers(msg)
|
||||
assert headers is not None
|
||||
assert headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == AUTH_DISABLED_USER_ID
|
||||
|
||||
def test_auth_disabled_user_id_overrides_bound_owner_for_local_visibility(self, monkeypatch):
|
||||
from app.gateway.auth_disabled import AUTH_DISABLED_USER_ID
|
||||
|
||||
manager = self._manager()
|
||||
monkeypatch.setenv("DEER_FLOW_AUTH_DISABLED", "1")
|
||||
msg = InboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C123",
|
||||
user_id="U-platform",
|
||||
owner_user_id="real-user-from-old-binding",
|
||||
text="hi",
|
||||
)
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
|
||||
assert run_context["user_id"] == AUTH_DISABLED_USER_ID
|
||||
assert run_context["channel_user_id"] == "U-platform"
|
||||
|
||||
def test_unbound_channel_messages_keep_platform_user_id_when_auth_is_enabled(self, monkeypatch):
|
||||
from app.channels.manager import _owner_headers
|
||||
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
msg = InboundMessage(channel_name="slack", chat_id="C123", user_id="U-platform", text="hi")
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
|
||||
assert run_context["user_id"] == "U-platform"
|
||||
assert run_context["channel_user_id"] == "U-platform"
|
||||
assert _owner_headers(msg) is None
|
||||
|
||||
def test_unsafe_user_id_is_normalized_but_raw_preserved(self, monkeypatch):
|
||||
from deerflow.config.paths import make_safe_user_id
|
||||
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
raw = "user@example.com"
|
||||
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
|
||||
|
||||
@@ -2347,9 +2461,32 @@ class TestResolveRunParamsUserId:
|
||||
assert run_context["user_id"] != raw
|
||||
assert run_context["channel_user_id"] == raw
|
||||
|
||||
@pytest.mark.parametrize("raw_user_id", ["", None])
|
||||
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id):
|
||||
def test_unsafe_user_id_migrates_unique_legacy_bucket(self, tmp_path, monkeypatch):
|
||||
from deerflow.config.paths import Paths, make_safe_user_id
|
||||
|
||||
paths = Paths(tmp_path)
|
||||
legacy_dir = paths.base_dir / "users" / "user-example-com-63a710569261a24b"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
(legacy_dir / "memory.json").write_text('{"legacy": true}\n', encoding="utf-8")
|
||||
monkeypatch.setattr("deerflow.config.paths.get_paths", lambda: paths)
|
||||
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
raw = "user@example.com"
|
||||
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw, text="hi")
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
|
||||
safe = make_safe_user_id(raw)
|
||||
assert run_context["user_id"] == safe
|
||||
assert paths.user_dir(safe).exists()
|
||||
assert not legacy_dir.exists()
|
||||
assert (paths.user_dir(safe) / "memory.json").read_text(encoding="utf-8") == '{"legacy": true}\n'
|
||||
|
||||
@pytest.mark.parametrize("raw_user_id", ["", None])
|
||||
def test_empty_or_none_user_id_is_not_injected(self, raw_user_id, monkeypatch):
|
||||
manager = self._manager()
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
msg = InboundMessage(channel_name="feishu", chat_id="c", user_id=raw_user_id, text="hi")
|
||||
|
||||
_, _, run_context = manager._resolve_run_params(msg, "thread-1")
|
||||
@@ -2358,6 +2495,93 @@ class TestResolveRunParamsUserId:
|
||||
assert "channel_user_id" not in run_context
|
||||
|
||||
|
||||
class TestChannelManagerConnectionRouting:
|
||||
def test_connection_scoped_conversations_do_not_share_threads(self, tmp_path, monkeypatch):
|
||||
from app.channels.manager import ChannelManager
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
from deerflow.persistence.engine import close_engine
|
||||
|
||||
monkeypatch.delenv("DEER_FLOW_AUTH_DISABLED", raising=False)
|
||||
|
||||
async def go():
|
||||
repo = await _make_channel_connection_repo(tmp_path)
|
||||
alice = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="slack",
|
||||
external_account_id="U-alice",
|
||||
workspace_id="T1",
|
||||
)
|
||||
bob = await repo.upsert_connection(
|
||||
owner_user_id="bob",
|
||||
provider="slack",
|
||||
external_account_id="U-bob",
|
||||
workspace_id="T1",
|
||||
)
|
||||
|
||||
bus = MessageBus()
|
||||
store = ChannelStore(path=tmp_path / "legacy-store.json")
|
||||
manager = ChannelManager(bus=bus, store=store, connection_repo=repo)
|
||||
mock_client = _make_mock_langgraph_client()
|
||||
mock_client.threads.create = AsyncMock(
|
||||
side_effect=[
|
||||
{"thread_id": "thread-alice"},
|
||||
{"thread_id": "thread-bob"},
|
||||
]
|
||||
)
|
||||
manager._client = mock_client
|
||||
|
||||
await manager._handle_chat(
|
||||
InboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C-shared",
|
||||
user_id="U-alice",
|
||||
owner_user_id="alice",
|
||||
connection_id=alice["id"],
|
||||
text="hello",
|
||||
thread_ts="1710000000.000100",
|
||||
topic_id="1710000000.000100",
|
||||
)
|
||||
)
|
||||
await manager._handle_chat(
|
||||
InboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C-shared",
|
||||
user_id="U-bob",
|
||||
owner_user_id="bob",
|
||||
connection_id=bob["id"],
|
||||
text="hello",
|
||||
thread_ts="1710000000.000100",
|
||||
topic_id="1710000000.000100",
|
||||
)
|
||||
)
|
||||
|
||||
assert await repo.get_thread_id(alice["id"], "C-shared", "1710000000.000100") == "thread-alice"
|
||||
assert await repo.get_thread_id(bob["id"], "C-shared", "1710000000.000100") == "thread-bob"
|
||||
assert store.list_entries() == []
|
||||
|
||||
first_context = mock_client.runs.wait.call_args_list[0].kwargs["context"]
|
||||
second_context = mock_client.runs.wait.call_args_list[1].kwargs["context"]
|
||||
assert first_context["user_id"] == "alice"
|
||||
assert first_context["channel_user_id"] == "U-alice"
|
||||
assert second_context["user_id"] == "bob"
|
||||
assert second_context["channel_user_id"] == "U-bob"
|
||||
|
||||
first_create_headers = mock_client.threads.create.call_args_list[0].kwargs["headers"]
|
||||
second_create_headers = mock_client.threads.create.call_args_list[1].kwargs["headers"]
|
||||
assert first_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
|
||||
assert second_create_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
|
||||
|
||||
first_run_headers = mock_client.runs.wait.call_args_list[0].kwargs["headers"]
|
||||
second_run_headers = mock_client.runs.wait.call_args_list[1].kwargs["headers"]
|
||||
assert first_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "alice"
|
||||
assert second_run_headers[INTERNAL_OWNER_USER_ID_HEADER_NAME] == "bob"
|
||||
|
||||
try:
|
||||
_run(go())
|
||||
finally:
|
||||
_run(close_engine())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ChannelService tests
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -3108,6 +3332,38 @@ class TestChannelService:
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_concurrent_ensure_channel_ready_starts_channel_once(self):
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
async def go():
|
||||
service = ChannelService(
|
||||
channels_config={
|
||||
"telegram": {"enabled": True, "bot_token": "tg-token"},
|
||||
}
|
||||
)
|
||||
await service.manager.start()
|
||||
service._running = True
|
||||
start_calls = []
|
||||
|
||||
async def fake_start_channel(name, config):
|
||||
start_calls.append(name)
|
||||
await asyncio.sleep(0.01)
|
||||
service._channels[name] = SimpleNamespace(is_running=True, stop=AsyncMock())
|
||||
return True
|
||||
|
||||
service._start_channel = fake_start_channel
|
||||
|
||||
results = await asyncio.gather(
|
||||
service.ensure_channel_ready("telegram"),
|
||||
service.ensure_channel_ready("telegram"),
|
||||
)
|
||||
|
||||
assert results == [True, True]
|
||||
assert start_calls == ["telegram"]
|
||||
await service.stop()
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_session_config_is_forwarded_to_manager(self):
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
@@ -3175,6 +3431,226 @@ class TestChannelService:
|
||||
|
||||
assert service._config == {"telegram": {"enabled": False}}
|
||||
|
||||
def test_from_app_config_does_not_create_runtime_channels_from_channel_connections(
|
||||
self,
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
):
|
||||
from app.channels.service import ChannelService
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(paths_module, "_paths", None)
|
||||
app_config = SimpleNamespace(
|
||||
model_extra={},
|
||||
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
|
||||
"slack": {"enabled": True},
|
||||
"discord": {"enabled": True},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
service = ChannelService.from_app_config(app_config)
|
||||
|
||||
assert service._config == {}
|
||||
|
||||
def test_from_app_config_preserves_existing_runtime_channels_with_channel_connections_enabled(
|
||||
self,
|
||||
monkeypatch,
|
||||
tmp_path,
|
||||
):
|
||||
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||
from app.channels.service import ChannelService
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(paths_module, "_paths", None)
|
||||
ChannelRuntimeConfigStore().set_provider_config(
|
||||
"slack",
|
||||
{
|
||||
"enabled": True,
|
||||
"bot_token": "xoxb-ui",
|
||||
"app_token": "xapp-ui",
|
||||
},
|
||||
)
|
||||
app_config = SimpleNamespace(
|
||||
model_extra={
|
||||
"channels": {
|
||||
"telegram": {"enabled": True, "bot_token": "telegram-token"},
|
||||
"slack": {"enabled": True, "bot_token": "xoxb", "app_token": "xapp"},
|
||||
"discord": {"enabled": True, "bot_token": "discord-bot-token"},
|
||||
}
|
||||
},
|
||||
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"telegram": {"enabled": True, "bot_username": "deerflow_bot"},
|
||||
"slack": {"enabled": True},
|
||||
"discord": {"enabled": True},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
service = ChannelService.from_app_config(app_config)
|
||||
|
||||
assert service._config["telegram"]["bot_token"] == "telegram-token"
|
||||
assert service._config["slack"]["app_token"] == "xapp"
|
||||
assert service._config["discord"]["bot_token"] == "discord-bot-token"
|
||||
|
||||
def test_from_app_config_loads_persisted_runtime_channel_config(self, monkeypatch, tmp_path):
|
||||
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||
from app.channels.service import ChannelService
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(paths_module, "_paths", None)
|
||||
ChannelRuntimeConfigStore().set_provider_config(
|
||||
"slack",
|
||||
{
|
||||
"enabled": True,
|
||||
"bot_token": "xoxb-ui",
|
||||
"app_token": "xapp-ui",
|
||||
},
|
||||
)
|
||||
app_config = SimpleNamespace(
|
||||
model_extra={},
|
||||
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"slack": {"enabled": True},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
service = ChannelService.from_app_config(app_config)
|
||||
|
||||
assert service._config["slack"] == {
|
||||
"enabled": True,
|
||||
"bot_token": "xoxb-ui",
|
||||
"app_token": "xapp-ui",
|
||||
}
|
||||
|
||||
def test_from_app_config_runtime_disconnect_suppresses_file_channel_config(self, monkeypatch, tmp_path):
|
||||
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||
from app.channels.service import ChannelService
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(paths_module, "_paths", None)
|
||||
ChannelRuntimeConfigStore().set_provider_config(
|
||||
"feishu",
|
||||
{
|
||||
"enabled": False,
|
||||
"_runtime_disabled": True,
|
||||
},
|
||||
)
|
||||
app_config = SimpleNamespace(
|
||||
model_extra={
|
||||
"channels": {
|
||||
"feishu": {
|
||||
"enabled": True,
|
||||
"app_id": "file-app-id",
|
||||
"app_secret": "file-secret",
|
||||
}
|
||||
}
|
||||
},
|
||||
channel_connections=ChannelConnectionsConfig.model_validate(
|
||||
{
|
||||
"enabled": True,
|
||||
"feishu": {"enabled": True},
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
service = ChannelService.from_app_config(app_config)
|
||||
|
||||
assert "feishu" not in service._config
|
||||
|
||||
def test_start_retries_configured_channel_until_ready(self, monkeypatch):
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
class FlakyReadyChannel(Channel):
|
||||
starts = 0
|
||||
|
||||
def __init__(self, bus, config):
|
||||
super().__init__(name="slack", bus=bus, config=config)
|
||||
|
||||
async def start(self):
|
||||
type(self).starts += 1
|
||||
self._running = type(self).starts >= 2
|
||||
|
||||
async def stop(self):
|
||||
self._running = False
|
||||
|
||||
async def send(self, msg):
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(
|
||||
"deerflow.reflection.resolve_class",
|
||||
lambda import_path, base_class=None: FlakyReadyChannel,
|
||||
)
|
||||
|
||||
async def go():
|
||||
service = ChannelService(
|
||||
channels_config={
|
||||
"slack": {
|
||||
"enabled": True,
|
||||
"bot_token": "xoxb-ui",
|
||||
"app_token": "xapp-ui",
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
try:
|
||||
await service.start()
|
||||
|
||||
assert FlakyReadyChannel.starts == 2
|
||||
assert service.get_status()["channels"]["slack"]["running"] is True
|
||||
finally:
|
||||
await service.stop()
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_connection_repo_is_forwarded_to_manager(self):
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
repo = object()
|
||||
service = ChannelService(channels_config={}, connection_repo=repo)
|
||||
|
||||
assert service.manager._connection_repo is repo
|
||||
|
||||
def test_remove_channel_stops_running_channel_and_forgets_config(self):
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
async def go():
|
||||
service = ChannelService(
|
||||
channels_config={
|
||||
"slack": {
|
||||
"enabled": True,
|
||||
"bot_token": "xoxb-ui",
|
||||
"app_token": "xapp-ui",
|
||||
},
|
||||
}
|
||||
)
|
||||
channel = AsyncMock()
|
||||
service._channels["slack"] = channel
|
||||
service._running = True
|
||||
|
||||
assert await service.remove_channel("slack") is True
|
||||
|
||||
channel.stop.assert_awaited_once()
|
||||
assert "slack" not in service._channels
|
||||
assert "slack" not in service._config
|
||||
|
||||
_run(go())
|
||||
|
||||
def test_disabled_channel_with_string_creds_emits_warning(self, caplog):
|
||||
"""Warning is emitted when a channel has string credentials but enabled=false."""
|
||||
import logging
|
||||
@@ -3192,7 +3668,8 @@ class TestChannelService:
|
||||
await service.stop()
|
||||
|
||||
_run(go())
|
||||
assert any("wecom" in r.message and r.levelno == logging.WARNING for r in caplog.records)
|
||||
assert any("credentials configured but is disabled" in r.message and r.levelno == logging.WARNING for r in caplog.records)
|
||||
assert all("wecom" not in r.message for r in caplog.records)
|
||||
|
||||
def test_disabled_channel_with_int_creds_emits_warning(self, caplog):
|
||||
"""Warning is emitted even when YAML-parsed integer credentials are present."""
|
||||
@@ -3212,7 +3689,8 @@ class TestChannelService:
|
||||
await service.stop()
|
||||
|
||||
_run(go())
|
||||
assert any("telegram" in r.message and r.levelno == logging.WARNING for r in caplog.records)
|
||||
assert any("credentials configured but is disabled" in r.message and r.levelno == logging.WARNING for r in caplog.records)
|
||||
assert all("telegram" not in r.message for r in caplog.records)
|
||||
|
||||
def test_disabled_channel_without_creds_emits_info(self, caplog):
|
||||
"""Only an info log (no warning) is emitted when a channel is disabled with no credentials."""
|
||||
@@ -3267,6 +3745,83 @@ class TestChannelService:
|
||||
assert started_configs["feishu"]["app_secret"] == "new_secret"
|
||||
assert service._config["feishu"]["app_id"] == "new_id"
|
||||
|
||||
def test_configure_channel_keeps_explicit_config_over_stale_file_entry(self, monkeypatch):
|
||||
"""UI-entered runtime credentials must not be clobbered by a config.yaml reload.
|
||||
|
||||
configure_channel() receives the authoritative config (e.g. from the
|
||||
browser Connect/Modify dialog, never written to config.yaml), so its
|
||||
restart must skip the file reload that restart_channel() performs for
|
||||
operator-triggered restarts.
|
||||
"""
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
stale_file_config = {"feishu": {"enabled": True, "app_id": "file_id", "app_secret": "file_secret"}}
|
||||
|
||||
def mock_get_app_config():
|
||||
return SimpleNamespace(model_extra={"channels": stale_file_config})
|
||||
|
||||
monkeypatch.setattr("deerflow.config.app_config.get_app_config", mock_get_app_config)
|
||||
|
||||
service = ChannelService(channels_config={})
|
||||
service._running = True
|
||||
|
||||
started_configs = {}
|
||||
|
||||
async def mock_start_channel(name, config):
|
||||
started_configs[name] = config
|
||||
return True
|
||||
|
||||
service._start_channel = mock_start_channel
|
||||
|
||||
async def go():
|
||||
await service.configure_channel("feishu", {"enabled": True, "app_id": "ui_id", "app_secret": "ui_secret"})
|
||||
|
||||
_run(go())
|
||||
|
||||
assert started_configs["feishu"]["app_id"] == "ui_id"
|
||||
assert started_configs["feishu"]["app_secret"] == "ui_secret"
|
||||
assert service._config["feishu"]["app_id"] == "ui_id"
|
||||
|
||||
def test_restart_channel_reload_applies_runtime_store_overlay(self, monkeypatch, tmp_path):
|
||||
"""An operator-triggered restart keeps UI runtime-store credentials for
|
||||
channels that have no config.yaml entry."""
|
||||
from app.channels.runtime_config_store import ChannelRuntimeConfigStore
|
||||
from app.channels.service import ChannelService
|
||||
from deerflow.config import paths as paths_module
|
||||
from deerflow.config.channel_connections_config import ChannelConnectionsConfig
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_HOME", str(tmp_path))
|
||||
monkeypatch.setattr(paths_module, "_paths", None)
|
||||
ChannelRuntimeConfigStore().set_provider_config(
|
||||
"telegram",
|
||||
{"enabled": True, "bot_token": "store-token"},
|
||||
)
|
||||
|
||||
def mock_get_app_config():
|
||||
return SimpleNamespace(
|
||||
model_extra={"channels": {}},
|
||||
channel_connections=ChannelConnectionsConfig.model_validate({"enabled": True, "telegram": {"enabled": True, "bot_username": "deerflow_bot"}}),
|
||||
)
|
||||
|
||||
monkeypatch.setattr("deerflow.config.app_config.get_app_config", mock_get_app_config)
|
||||
|
||||
service = ChannelService(channels_config={})
|
||||
|
||||
started_configs = {}
|
||||
|
||||
async def mock_start_channel(name, config):
|
||||
started_configs[name] = config
|
||||
return True
|
||||
|
||||
service._start_channel = mock_start_channel
|
||||
|
||||
async def go():
|
||||
await service.restart_channel("telegram")
|
||||
|
||||
_run(go())
|
||||
|
||||
assert started_configs["telegram"]["bot_token"] == "store-token"
|
||||
|
||||
def test_restart_channel_falls_back_to_cached_config_on_error(self, monkeypatch):
|
||||
"""When get_app_config() fails, restart_channel uses cached config."""
|
||||
from app.channels.service import ChannelService
|
||||
|
||||
@@ -233,3 +233,15 @@ def test_non_auth_mutation_rejects_mismatched_double_submit_token():
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "CSRF token mismatch."
|
||||
|
||||
|
||||
def test_channel_posts_require_double_submit_csrf():
|
||||
client = TestClient(_make_app(), base_url="https://deerflow.example")
|
||||
|
||||
response = client.post(
|
||||
"/api/channels/slack/connect",
|
||||
headers={"Origin": "https://deerflow.example"},
|
||||
)
|
||||
|
||||
assert response.status_code == 403
|
||||
assert response.json()["detail"] == "CSRF token missing. Include X-CSRF-Token header."
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
"""Discord connection routing tests."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.channels.discord import DiscordChannel
|
||||
from app.channels.message_bus import InboundMessage, MessageBus
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def repo(tmp_path):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'discord.db'}", sqlite_dir=str(tmp_path))
|
||||
try:
|
||||
yield ChannelConnectionRepository(
|
||||
get_session_factory(),
|
||||
cipher=ChannelCredentialCipher.from_key("discord-secret"),
|
||||
)
|
||||
finally:
|
||||
await close_engine()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_discord_inbound_attaches_owner_identity_from_user_level_connection(repo):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="alice",
|
||||
provider="discord",
|
||||
external_account_id="987",
|
||||
external_account_name="Alice",
|
||||
status="connected",
|
||||
)
|
||||
channel = DiscordChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_token": "discord-bot", "connection_repo": repo},
|
||||
)
|
||||
inbound = InboundMessage(
|
||||
channel_name="discord",
|
||||
chat_id="C123",
|
||||
user_id="987",
|
||||
text="hello",
|
||||
)
|
||||
|
||||
attached = await channel._attach_connection_identity(inbound, guild_id="G123")
|
||||
|
||||
assert attached.connection_id == connection["id"]
|
||||
assert attached.owner_user_id == "alice"
|
||||
assert attached.workspace_id is None
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_discord_connect_command_binds_gateway_identity(repo):
|
||||
state = "discord-bind-code"
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="discord",
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
channel = DiscordChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_token": "discord-bot", "connection_repo": repo},
|
||||
)
|
||||
message = MagicMock()
|
||||
message.author.id = 987
|
||||
message.author.display_name = "Alice"
|
||||
message.guild.id = 123
|
||||
message.guild.name = "Deer Guild"
|
||||
message.channel.id = 456
|
||||
message.channel.send = AsyncMock()
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(message, state)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "discord"
|
||||
assert connections[0]["external_account_id"] == "987"
|
||||
assert connections[0]["external_account_name"] == "Alice"
|
||||
assert connections[0]["workspace_id"] == "123"
|
||||
assert connections[0]["workspace_name"] == "Deer Guild"
|
||||
assert connections[0]["metadata"]["channel_id"] == "456"
|
||||
message.channel.send.assert_awaited_once()
|
||||
@@ -73,6 +73,31 @@ def test_feishu_on_message_plain_text():
|
||||
assert mock_make_inbound.call_args[1]["text"] == "Hello world"
|
||||
|
||||
|
||||
def test_feishu_is_not_running_when_ws_thread_exits():
|
||||
bus = MessageBus()
|
||||
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
|
||||
channel._running = True
|
||||
channel._thread = MagicMock()
|
||||
channel._thread.is_alive.return_value = False
|
||||
|
||||
assert channel.is_running is False
|
||||
|
||||
|
||||
def test_feishu_event_handler_ignores_non_content_message_events():
|
||||
import lark_oapi as lark
|
||||
|
||||
bus = MessageBus()
|
||||
channel = FeishuChannel(bus, {"app_id": "test", "app_secret": "test"})
|
||||
|
||||
event_handler = channel._build_event_handler(lark)
|
||||
|
||||
assert "p2.im.message.receive_v1" in event_handler._processorMap
|
||||
assert "p2.im.message.message_read_v1" in event_handler._processorMap
|
||||
assert "p2.im.message.reaction.created_v1" in event_handler._processorMap
|
||||
assert "p2.im.message.reaction.deleted_v1" in event_handler._processorMap
|
||||
assert "p2.im.message.recalled_v1" in event_handler._processorMap
|
||||
|
||||
|
||||
def test_feishu_on_message_rich_text():
|
||||
bus = MessageBus()
|
||||
config = {"app_id": "test", "app_secret": "test"}
|
||||
|
||||
@@ -4,6 +4,18 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from deerflow.config.app_config import AppConfig, reset_app_config, set_app_config
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _stub_app_config():
|
||||
"""Keep run-context tests independent from a developer-local config.yaml."""
|
||||
set_app_config(AppConfig.model_validate({"sandbox": {"use": "deerflow.sandbox.local:LocalSandboxProvider"}}))
|
||||
yield
|
||||
reset_app_config()
|
||||
|
||||
|
||||
def test_format_sse_basic():
|
||||
from app.gateway.services import format_sse
|
||||
@@ -36,6 +48,12 @@ def test_format_sse_no_event_id():
|
||||
assert "id:" not in frame
|
||||
|
||||
|
||||
def test_sanitize_log_param_strips_control_characters():
|
||||
from app.gateway.utils import sanitize_log_param
|
||||
|
||||
assert sanitize_log_param("thread\nid\rwith\x00controls") == "threadidwithcontrols"
|
||||
|
||||
|
||||
def test_normalize_stream_modes_none():
|
||||
from app.gateway.services import normalize_stream_modes
|
||||
|
||||
@@ -474,6 +492,83 @@ def test_inject_authenticated_user_context_skips_internal_role():
|
||||
assert config["context"]["user_id"] == "channel-user-7"
|
||||
|
||||
|
||||
def test_start_run_uses_internal_owner_header_for_persistence(_stub_app_config):
|
||||
import asyncio
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from langgraph.store.memory import InMemoryStore
|
||||
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||
from app.gateway.services import start_run
|
||||
from deerflow.persistence.thread_meta.memory import MemoryThreadMetaStore
|
||||
from deerflow.runtime import RunManager
|
||||
from deerflow.runtime.runs.store.memory import MemoryRunStore
|
||||
from deerflow.runtime.user_context import get_effective_user_id
|
||||
|
||||
async def _scenario():
|
||||
run_store = MemoryRunStore()
|
||||
thread_store = MemoryThreadMetaStore(InMemoryStore())
|
||||
await thread_store.create("channel-thread", user_id="default", metadata={"legacy": True})
|
||||
run_manager = RunManager(store=run_store)
|
||||
state = SimpleNamespace(
|
||||
stream_bridge=SimpleNamespace(),
|
||||
run_manager=run_manager,
|
||||
checkpointer=InMemorySaver(),
|
||||
store=InMemoryStore(),
|
||||
run_event_store=SimpleNamespace(),
|
||||
run_events_config=None,
|
||||
thread_store=thread_store,
|
||||
)
|
||||
request = SimpleNamespace(
|
||||
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
|
||||
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
|
||||
app=SimpleNamespace(state=state),
|
||||
)
|
||||
body = SimpleNamespace(
|
||||
assistant_id="lead_agent",
|
||||
input={"messages": [{"role": "human", "content": "hi"}]},
|
||||
metadata={},
|
||||
config=None,
|
||||
context=None,
|
||||
on_disconnect="cancel",
|
||||
multitask_strategy="reject",
|
||||
stream_mode=None,
|
||||
stream_subgraphs=False,
|
||||
interrupt_before=None,
|
||||
interrupt_after=None,
|
||||
)
|
||||
task_context: dict[str, str] = {}
|
||||
|
||||
async def fake_run_agent(*args, **kwargs):
|
||||
task_context["user_id"] = get_effective_user_id()
|
||||
|
||||
with (
|
||||
patch("app.gateway.services.resolve_agent_factory", return_value=object()),
|
||||
patch("app.gateway.services.run_agent", side_effect=fake_run_agent),
|
||||
):
|
||||
record = await start_run(body, "channel-thread", request)
|
||||
await record.task
|
||||
|
||||
owner_run = await run_store.get(record.run_id, user_id="owner-1")
|
||||
default_run = await run_store.get(record.run_id, user_id="default")
|
||||
owner_thread = await thread_store.get("channel-thread", user_id="owner-1")
|
||||
default_thread = await thread_store.get("channel-thread", user_id="default")
|
||||
return owner_run, default_run, owner_thread, default_thread, task_context
|
||||
|
||||
owner_run, default_run, owner_thread, default_thread, task_context = asyncio.run(_scenario())
|
||||
|
||||
assert owner_run is not None
|
||||
assert owner_run["user_id"] == "owner-1"
|
||||
assert default_run is None
|
||||
assert owner_thread is not None
|
||||
assert owner_thread["user_id"] == "owner-1"
|
||||
assert owner_thread["metadata"] == {"legacy": True}
|
||||
assert default_thread is None
|
||||
assert task_context["user_id"] == "owner-1"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# build_run_config — context / configurable precedence (LangGraph >= 0.6.0)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -33,3 +33,18 @@ def test_internal_auth_generates_process_local_fallback(monkeypatch):
|
||||
assert reloaded.is_valid_internal_auth_token(token) is True
|
||||
finally:
|
||||
importlib.reload(reloaded)
|
||||
|
||||
|
||||
def test_internal_auth_headers_can_carry_owner_user_id(monkeypatch):
|
||||
import app.gateway.internal_auth as internal_auth
|
||||
|
||||
monkeypatch.setenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", "shared-token")
|
||||
reloaded = importlib.reload(internal_auth)
|
||||
try:
|
||||
headers = reloaded.create_internal_auth_headers(owner_user_id="owner-1")
|
||||
|
||||
assert headers[reloaded.INTERNAL_AUTH_HEADER_NAME] == "shared-token"
|
||||
assert headers[reloaded.INTERNAL_OWNER_USER_ID_HEADER_NAME] == "owner-1"
|
||||
finally:
|
||||
monkeypatch.delenv("DEER_FLOW_INTERNAL_AUTH_TOKEN", raising=False)
|
||||
importlib.reload(reloaded)
|
||||
|
||||
@@ -44,6 +44,7 @@ class TestMakeSafeUserId:
|
||||
# Sanitized prefix plus a stable digest of the original.
|
||||
assert result.startswith("user-example-com-")
|
||||
assert len(result.rsplit("-", 1)[1]) == 16
|
||||
assert result == "user-example-com-b4c9a289323b21a0"
|
||||
assert make_safe_user_id("user@example.com") == result
|
||||
|
||||
def test_sanitized_id_passes_validation(self, paths: Paths):
|
||||
@@ -69,6 +70,40 @@ class TestUserDir:
|
||||
def test_user_dir(self, paths: Paths):
|
||||
assert paths.user_dir("alice") == paths.base_dir / "users" / "alice"
|
||||
|
||||
def test_prepare_user_dir_migrates_unique_legacy_unsafe_bucket(self, paths: Paths):
|
||||
from deerflow.config.paths import make_safe_user_id
|
||||
|
||||
raw = "user@example.com"
|
||||
safe = make_safe_user_id(raw)
|
||||
legacy_dir = paths.base_dir / "users" / "user-example-com-63a710569261a24b"
|
||||
legacy_dir.mkdir(parents=True)
|
||||
(legacy_dir / "memory.json").write_text('{"legacy": true}\n', encoding="utf-8")
|
||||
|
||||
assert paths.prepare_user_dir_for_raw_id(raw) == safe
|
||||
|
||||
current_dir = paths.user_dir(safe)
|
||||
assert current_dir.exists()
|
||||
assert not legacy_dir.exists()
|
||||
assert (current_dir / "memory.json").read_text(encoding="utf-8") == '{"legacy": true}\n'
|
||||
|
||||
def test_prepare_user_dir_never_migrates_another_users_bucket(self, paths: Paths):
|
||||
"""A different raw ID with the same sanitized prefix has a different legacy digest."""
|
||||
import hashlib
|
||||
|
||||
from deerflow.config.paths import make_safe_user_id
|
||||
|
||||
users_dir = paths.base_dir / "users"
|
||||
other_legacy = users_dir / f"a-b-{hashlib.sha1(b'a/b').hexdigest()[:16]}"
|
||||
other_legacy.mkdir(parents=True)
|
||||
arbitrary_16_hex = users_dir / "a-b-1111111111111111"
|
||||
arbitrary_16_hex.mkdir(parents=True)
|
||||
|
||||
assert paths.prepare_user_dir_for_raw_id("a.b") == make_safe_user_id("a.b")
|
||||
|
||||
assert not paths.user_dir(make_safe_user_id("a.b")).exists()
|
||||
assert other_legacy.exists()
|
||||
assert arbitrary_16_hex.exists()
|
||||
|
||||
|
||||
class TestUserMemoryFile:
|
||||
def test_user_memory_file(self, paths: Paths):
|
||||
|
||||
@@ -90,6 +90,7 @@ def test_appconfig_descriptions_retain_original_field_documentation():
|
||||
"run_events": "memory for dev",
|
||||
"checkpointer": "state-persistence checkpointer",
|
||||
"stream_bridge": "Stream bridge",
|
||||
"channel_connections": "IM channel connection",
|
||||
}
|
||||
for field_name, expected_substring in descriptions.items():
|
||||
description = AppConfig.model_fields[field_name].description or ""
|
||||
|
||||
@@ -7,7 +7,9 @@ Run from repo root:
|
||||
from __future__ import annotations
|
||||
|
||||
import yaml
|
||||
from wizard import ui as wizard_ui
|
||||
from wizard.providers import LLM_PROVIDERS, SEARCH_PROVIDERS, WEB_FETCH_PROVIDERS, LLMProvider
|
||||
from wizard.steps import channels as channels_step
|
||||
from wizard.steps import llm as llm_step
|
||||
from wizard.steps import search as search_step
|
||||
from wizard.writer import (
|
||||
@@ -327,6 +329,44 @@ class TestBuildMinimalConfig:
|
||||
assert model["when_thinking_enabled"]["extra_body"]["thinking"]["type"] == "enabled"
|
||||
assert model["when_thinking_disabled"]["extra_body"]["thinking"]["type"] == "disabled"
|
||||
|
||||
def test_can_enable_selected_channel_connections(self):
|
||||
content = build_minimal_config(
|
||||
provider_use="langchain_openai:ChatOpenAI",
|
||||
model_name="gpt-4o",
|
||||
display_name="OpenAI",
|
||||
api_key_field="api_key",
|
||||
env_var="OPENAI_API_KEY",
|
||||
channel_connection_providers=["feishu", "slack"],
|
||||
)
|
||||
|
||||
data = yaml.safe_load(content)
|
||||
channel_connections = data["channel_connections"]
|
||||
|
||||
assert channel_connections["enabled"] is True
|
||||
assert channel_connections["feishu"]["enabled"] is True
|
||||
assert channel_connections["slack"]["enabled"] is True
|
||||
assert channel_connections["telegram"]["enabled"] is False
|
||||
assert channel_connections["discord"]["enabled"] is False
|
||||
assert channel_connections["dingtalk"]["enabled"] is False
|
||||
assert channel_connections["wechat"]["enabled"] is False
|
||||
assert channel_connections["wecom"]["enabled"] is False
|
||||
|
||||
def test_channel_connections_disabled_when_no_channels_selected(self):
|
||||
content = build_minimal_config(
|
||||
provider_use="langchain_openai:ChatOpenAI",
|
||||
model_name="gpt-4o",
|
||||
display_name="OpenAI",
|
||||
api_key_field="api_key",
|
||||
env_var="OPENAI_API_KEY",
|
||||
channel_connection_providers=[],
|
||||
)
|
||||
|
||||
data = yaml.safe_load(content)
|
||||
channel_connections = data["channel_connections"]
|
||||
|
||||
assert channel_connections["enabled"] is False
|
||||
assert all(not config["enabled"] for provider, config in channel_connections.items() if provider != "enabled")
|
||||
|
||||
|
||||
class TestLLMStep:
|
||||
def test_model_selection_defaults_to_provider_default_model(self, monkeypatch):
|
||||
@@ -384,6 +424,41 @@ class TestLLMStep:
|
||||
assert result.base_url == "https://gateway.example/v1"
|
||||
|
||||
|
||||
class TestChannelsStep:
|
||||
def test_returns_selected_channel_keys(self, monkeypatch):
|
||||
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [0, 3, 6])
|
||||
|
||||
result = channels_step.run_channels_step()
|
||||
|
||||
assert result.enabled_providers == ["telegram", "feishu", "wecom"]
|
||||
|
||||
def test_empty_selection_disables_channel_connections(self, monkeypatch):
|
||||
monkeypatch.setattr(channels_step, "print_header", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(channels_step, "print_info", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(channels_step, "print_success", lambda *_args, **_kwargs: None)
|
||||
monkeypatch.setattr(channels_step, "ask_multi_choice", lambda *_args, **_kwargs: [])
|
||||
|
||||
result = channels_step.run_channels_step()
|
||||
|
||||
assert result.enabled_providers == []
|
||||
|
||||
|
||||
class TestWizardUi:
|
||||
def test_multi_choice_blank_requires_input_without_default(self, monkeypatch):
|
||||
answers = iter(["", "2"])
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt: next(answers))
|
||||
|
||||
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=None) == [1]
|
||||
|
||||
def test_multi_choice_blank_accepts_empty_default(self, monkeypatch):
|
||||
monkeypatch.setattr("builtins.input", lambda _prompt: "")
|
||||
|
||||
assert wizard_ui.ask_multi_choice("Pick", ["First", "Second"], default=[]) == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# writer.py — env file helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
"""Slack connection tests for user-owned channel bindings."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from types import ModuleType
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from app.channels.message_bus import MessageBus, OutboundMessage
|
||||
|
||||
|
||||
async def _make_repo(tmp_path):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'slack.db'}", sqlite_dir=str(tmp_path))
|
||||
return ChannelConnectionRepository(
|
||||
get_session_factory(),
|
||||
cipher=ChannelCredentialCipher.from_key("slack-secret"),
|
||||
)
|
||||
|
||||
|
||||
def test_slack_connect_command_binds_socket_mode_identity(tmp_path):
|
||||
import anyio
|
||||
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
async def go():
|
||||
repo = await _make_repo(tmp_path)
|
||||
state = "slack-bind-code"
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="slack",
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
channel = SlackChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_token": "xoxb-operator", "app_token": "xapp-operator", "connection_repo": repo},
|
||||
)
|
||||
channel._web_client = MagicMock()
|
||||
|
||||
handled = await channel._bind_connection_from_connect_code(
|
||||
event={
|
||||
"user": "U123",
|
||||
"channel": "C123",
|
||||
"ts": "1710000000.000100",
|
||||
},
|
||||
team_id="T123",
|
||||
code=state,
|
||||
)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert handled is True
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "slack"
|
||||
assert connections[0]["external_account_id"] == "U123"
|
||||
assert connections[0]["workspace_id"] == "T123"
|
||||
assert connections[0]["metadata"]["channel_id"] == "C123"
|
||||
channel._web_client.chat_postMessage.assert_called_once()
|
||||
await repo.close()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_slack_send_uses_connection_bot_token_when_connection_id_is_present():
|
||||
import anyio
|
||||
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
async def go():
|
||||
repo = AsyncMock()
|
||||
repo.get_credentials.return_value = {"access_token": "xoxb-connection-token"}
|
||||
web_client = MagicMock()
|
||||
web_client_factory = MagicMock(return_value=web_client)
|
||||
channel = SlackChannel(
|
||||
bus=MessageBus(),
|
||||
config={
|
||||
"connection_repo": repo,
|
||||
"web_client_factory": web_client_factory,
|
||||
},
|
||||
)
|
||||
|
||||
msg = OutboundMessage(
|
||||
channel_name="slack",
|
||||
chat_id="C123",
|
||||
thread_id="thread-1",
|
||||
text="hello",
|
||||
connection_id="connection-1",
|
||||
)
|
||||
await channel.send(msg)
|
||||
|
||||
repo.get_credentials.assert_awaited_once_with("connection-1")
|
||||
web_client_factory.assert_called_once_with(token="xoxb-connection-token")
|
||||
web_client.chat_postMessage.assert_called_once()
|
||||
|
||||
anyio.run(go)
|
||||
|
||||
|
||||
def test_slack_http_events_mode_initializes_operator_web_client(monkeypatch):
|
||||
import anyio
|
||||
|
||||
from app.channels.slack import SlackChannel
|
||||
|
||||
class FakeWebClient:
|
||||
def __init__(self, token: str) -> None:
|
||||
self.token = token
|
||||
self.messages: list[dict] = []
|
||||
|
||||
def auth_test(self):
|
||||
return {"user_id": "B-http"}
|
||||
|
||||
def chat_postMessage(self, **kwargs):
|
||||
self.messages.append(kwargs)
|
||||
|
||||
slack_sdk = ModuleType("slack_sdk")
|
||||
slack_sdk.WebClient = FakeWebClient
|
||||
socket_mode = ModuleType("slack_sdk.socket_mode")
|
||||
socket_mode.SocketModeClient = object
|
||||
response = ModuleType("slack_sdk.socket_mode.response")
|
||||
response.SocketModeResponse = object
|
||||
monkeypatch.setitem(sys.modules, "slack_sdk", slack_sdk)
|
||||
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode", socket_mode)
|
||||
monkeypatch.setitem(sys.modules, "slack_sdk.socket_mode.response", response)
|
||||
|
||||
async def go():
|
||||
channel = SlackChannel(
|
||||
bus=MessageBus(),
|
||||
config={
|
||||
"bot_token": "xoxb-operator",
|
||||
"event_delivery": "http",
|
||||
"connection_repo": MagicMock(),
|
||||
},
|
||||
)
|
||||
|
||||
await channel.start()
|
||||
assert channel._running is True
|
||||
assert channel._web_client is not None
|
||||
assert channel._web_client.token == "xoxb-operator"
|
||||
assert channel._bot_user_id == "B-http"
|
||||
|
||||
await channel._post_connection_reply("C123", "Slack connected to DeerFlow.", "1710000000.000100")
|
||||
|
||||
assert channel._web_client.messages == [
|
||||
{
|
||||
"channel": "C123",
|
||||
"text": "Slack connected to DeerFlow.",
|
||||
"thread_ts": "1710000000.000100",
|
||||
}
|
||||
]
|
||||
await channel.stop()
|
||||
|
||||
anyio.run(go)
|
||||
@@ -164,10 +164,42 @@ def test_stream_shared_thread_passes_owner_check():
|
||||
create_or_reject.assert_awaited()
|
||||
|
||||
|
||||
def test_stream_internal_role_bypasses_owner_check():
|
||||
"""IM channels run with the internal system role on behalf of platform
|
||||
users whose threads they do not own — the owner check must not break them."""
|
||||
def test_stream_internal_role_scoped_by_owner_header():
|
||||
"""IM channels run with the internal system role on behalf of the
|
||||
connection owner named in X-DeerFlow-Owner-User-Id — the owner check is
|
||||
scoped to that owner rather than bypassed."""
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
|
||||
with _client(INTERNAL_USER) as (client, create_or_reject):
|
||||
response = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||
response = client.post(
|
||||
"/api/runs/stream",
|
||||
json=_body(THREAD_A),
|
||||
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: str(USER_A.id)},
|
||||
)
|
||||
assert response.status_code == 409
|
||||
create_or_reject.assert_awaited()
|
||||
|
||||
|
||||
def test_stream_internal_role_with_foreign_owner_header_returns_404():
|
||||
"""The internal token alone must not grant access to another user's thread."""
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME
|
||||
|
||||
with _client(INTERNAL_USER) as (client, create_or_reject):
|
||||
response = client.post(
|
||||
"/api/runs/stream",
|
||||
json=_body(THREAD_A),
|
||||
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: str(USER_B.id)},
|
||||
)
|
||||
assert response.status_code == 404
|
||||
create_or_reject.assert_not_awaited()
|
||||
|
||||
|
||||
def test_stream_internal_role_without_owner_header_is_scoped_to_internal_user():
|
||||
"""Without an owner header internal callers keep access to their own and
|
||||
shared/untracked threads, but not to user-owned threads."""
|
||||
with _client(INTERNAL_USER) as (client, create_or_reject):
|
||||
denied = client.post("/api/runs/stream", json=_body(THREAD_A))
|
||||
allowed = client.post("/api/runs/stream", json=_body(THREAD_SHARED))
|
||||
assert denied.status_code == 404
|
||||
assert allowed.status_code == 409
|
||||
create_or_reject.assert_awaited()
|
||||
|
||||
@@ -0,0 +1,100 @@
|
||||
"""Tests for Telegram deep-link channel connections."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.channels.message_bus import MessageBus
|
||||
from app.channels.telegram import TelegramChannel
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def repo(tmp_path: Path):
|
||||
from deerflow.persistence.channel_connections import ChannelConnectionRepository, ChannelCredentialCipher
|
||||
from deerflow.persistence.engine import close_engine, get_session_factory, init_engine
|
||||
|
||||
await init_engine("sqlite", url=f"sqlite+aiosqlite:///{tmp_path / 'telegram.db'}", sqlite_dir=str(tmp_path))
|
||||
try:
|
||||
yield ChannelConnectionRepository(
|
||||
get_session_factory(),
|
||||
cipher=ChannelCredentialCipher.from_key("telegram-secret"),
|
||||
)
|
||||
finally:
|
||||
await close_engine()
|
||||
|
||||
|
||||
def _telegram_update(*, text: str = "/start", user_id: int = 42, chat_id: int = 100, chat_type: str = "private"):
|
||||
update = MagicMock()
|
||||
update.effective_user.id = user_id
|
||||
update.effective_user.username = "alice"
|
||||
update.effective_user.full_name = "Alice Example"
|
||||
update.effective_chat.id = chat_id
|
||||
update.effective_chat.type = chat_type
|
||||
update.message.text = text
|
||||
update.message.message_id = 55
|
||||
update.message.reply_to_message = None
|
||||
update.message.reply_text = AsyncMock()
|
||||
return update
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_start_with_deep_link_state_binds_telegram_chat(repo):
|
||||
state = "telegram-bind-state"
|
||||
await repo.create_oauth_state(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="telegram",
|
||||
state=state,
|
||||
expires_at=datetime.now(UTC) + timedelta(minutes=5),
|
||||
)
|
||||
channel = TelegramChannel(
|
||||
bus=MessageBus(),
|
||||
config={"bot_token": "test-token", "connection_repo": repo},
|
||||
)
|
||||
update = _telegram_update(text=f"/start {state}")
|
||||
context = MagicMock()
|
||||
context.args = [state]
|
||||
|
||||
await channel._cmd_start(update, context)
|
||||
|
||||
connections = await repo.list_connections("deerflow-user-1")
|
||||
assert len(connections) == 1
|
||||
assert connections[0]["provider"] == "telegram"
|
||||
assert connections[0]["external_account_id"] == "42"
|
||||
assert connections[0]["external_account_name"] == "Alice Example"
|
||||
assert connections[0]["workspace_id"] == "100"
|
||||
assert connections[0]["metadata"]["chat_type"] == "private"
|
||||
update.message.reply_text.assert_awaited_once()
|
||||
assert "connected" in update.message.reply_text.await_args.args[0].lower()
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_bound_telegram_message_publishes_connection_identity(repo):
|
||||
connection = await repo.upsert_connection(
|
||||
owner_user_id="deerflow-user-1",
|
||||
provider="telegram",
|
||||
external_account_id="42",
|
||||
external_account_name="Alice Example",
|
||||
workspace_id="100",
|
||||
metadata={"chat_type": "private"},
|
||||
)
|
||||
bus = MessageBus()
|
||||
channel = TelegramChannel(
|
||||
bus=bus,
|
||||
config={"bot_token": "test-token", "connection_repo": repo},
|
||||
)
|
||||
channel._main_loop = __import__("asyncio").get_event_loop()
|
||||
channel._send_running_reply = AsyncMock()
|
||||
|
||||
await channel._on_text(_telegram_update(text="hello"), None)
|
||||
inbound = await bus.get_inbound()
|
||||
|
||||
assert inbound.connection_id == connection["id"]
|
||||
assert inbound.owner_user_id == "deerflow-user-1"
|
||||
assert inbound.workspace_id == "100"
|
||||
assert inbound.user_id == "42"
|
||||
assert inbound.chat_id == "100"
|
||||
assert inbound.text == "hello"
|
||||
@@ -137,6 +137,19 @@ class TestThreadMetaRepository:
|
||||
async def test_update_metadata_nonexistent_is_noop(self, repo):
|
||||
await repo.update_metadata("nonexistent", {"k": "v"}) # should not raise
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_update_owner_with_bypass_moves_row(self, repo):
|
||||
await repo.create("t1", user_id="default", metadata={"source": "channel"})
|
||||
await repo.update_owner("t1", "owner-1", user_id=None)
|
||||
|
||||
owner_row = await repo.get("t1", user_id="owner-1")
|
||||
default_row = await repo.get("t1", user_id="default")
|
||||
|
||||
assert owner_row is not None
|
||||
assert owner_row["user_id"] == "owner-1"
|
||||
assert owner_row["metadata"] == {"source": "channel"}
|
||||
assert default_row is None
|
||||
|
||||
# --- search with metadata filter (SQL push-down) ---
|
||||
|
||||
@pytest.mark.anyio
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import re
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
@@ -218,6 +219,37 @@ def test_create_thread_returns_iso_timestamps() -> None:
|
||||
assert body["created_at"] == body["updated_at"]
|
||||
|
||||
|
||||
def test_internal_owner_header_assigns_thread_to_owner() -> None:
|
||||
import asyncio
|
||||
|
||||
from app.gateway.internal_auth import INTERNAL_OWNER_USER_ID_HEADER_NAME, INTERNAL_SYSTEM_ROLE
|
||||
|
||||
store = InMemoryStore()
|
||||
checkpointer = InMemorySaver()
|
||||
thread_store = MemoryThreadMetaStore(store)
|
||||
request = SimpleNamespace(
|
||||
headers={INTERNAL_OWNER_USER_ID_HEADER_NAME: "owner-1"},
|
||||
state=SimpleNamespace(user=SimpleNamespace(id="default", system_role=INTERNAL_SYSTEM_ROLE)),
|
||||
app=SimpleNamespace(state=SimpleNamespace(checkpointer=checkpointer, thread_store=thread_store)),
|
||||
)
|
||||
|
||||
async def _scenario():
|
||||
response = await threads.create_thread(
|
||||
threads.ThreadCreateRequest(thread_id="channel-thread", metadata={}),
|
||||
request,
|
||||
)
|
||||
owner_row = await thread_store.get("channel-thread", user_id="owner-1")
|
||||
internal_row = await thread_store.get("channel-thread", user_id="default")
|
||||
return response, owner_row, internal_row
|
||||
|
||||
response, owner_row, internal_row = asyncio.run(_scenario())
|
||||
|
||||
assert response.thread_id == "channel-thread"
|
||||
assert owner_row is not None
|
||||
assert owner_row["user_id"] == "owner-1"
|
||||
assert internal_row is None
|
||||
|
||||
|
||||
def test_get_thread_returns_iso_for_legacy_unix_record() -> None:
|
||||
"""A thread record written by older versions stores ``time.time()``
|
||||
floats. ``get_thread`` must transparently surface them as ISO so the
|
||||
|
||||
Reference in New Issue
Block a user